1#define(VARIANTS)
2
3[
4 {
5 "DECLS": ["NOT_INPLACE"]
6 },
7 {
8 "SHADER_SUFFIX": "inplace",
9 "DECLS": ["INPLACE"]
10 },
11]
12
13#end(VARIANTS)
14
15#define(DECLS)
16
17#decl(NOT_INPLACE)
18
19fn update(src_offset: u32, dst_offset: u32, scale: f32) {
20 dst[dst_offset] = scale * src[src_offset];
21}
22
23@group(0) @binding(1)
24var<storage, read_write> dst: array<f32>;
25
26@group(0) @binding(2)
27var<uniform> params: Params;
28
29#enddecl(NOT_INPLACE)
30
31#decl(INPLACE)
32
33fn update(src_offset: u32, dst_offset: u32, scale: f32) {
34 src[dst_offset] = scale * src[src_offset];
35}
36
37@group(0) @binding(1)
38var<uniform> params: Params;
39
40#enddecl(INPLACE)
41
42#end(DECLS)
43
44#define(SHADER)
45
46struct Params {
47 offset_src: u32, // in elements
48 offset_dst: u32, // in elements
49
50 // Strides (in elements)
51 stride_src1: u32,
52 stride_src2: u32,
53 stride_src3: u32,
54
55 stride_dst1: u32,
56 stride_dst2: u32,
57 stride_dst3: u32,
58
59 // Shape of src/dst
60 ne0: u32,
61 ne1: u32,
62 ne2: u32,
63 ne3: u32,
64
65 eps: f32
66};
67
68@group(0) @binding(0)
69var<storage, read_write> src: array<f32>;
70
71DECLS
72
73override wg_size: u32;
74var<workgroup> scratch: array<f32, wg_size>;
75
76@compute @workgroup_size(wg_size)
77fn main(@builtin(workgroup_id) wid: vec3<u32>,
78 @builtin(local_invocation_id) lid: vec3<u32>) {
79
80 // one thread per row
81 var i = wid.x;
82 let i3 = i / (params.ne2 * params.ne1);
83 i = i % (params.ne2 * params.ne1);
84 let i2 = i / params.ne1;
85 let i1 = i % params.ne1;
86 let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
87 let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
88
89 let elems = (params.ne0 + wg_size - 1) / wg_size;
90
91 var sum = 0.0f;
92 var col = lid.x;
93 for (var j: u32 = 0; j < elems; j++) {
94 if (col >= params.ne0) {
95 break;
96 }
97 sum += pow(src[i_src_row + col], 2.0);
98 col += wg_size;
99 }
100
101 scratch[lid.x] = sum;
102 workgroupBarrier();
103 var offset = wg_size / 2;
104 while (offset > 0) {
105 if (lid.x < offset) {
106 scratch[lid.x] += scratch[lid.x + offset];
107 }
108 offset = offset / 2;
109 workgroupBarrier();
110 }
111 sum = scratch[0];
112
113 let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
114 col = lid.x;
115 for (var j: u32 = 0; j < elems; j++) {
116 if (col >= params.ne0) {
117 break;
118 }
119 update(i_src_row + col, i_dst_row + col, scale);
120 col += wg_size;
121 }
122}
123#end(SHADER)