// Element type used for the tensor buffers and the per-element result:
// f16 (with the required WGSL extension enabled) when TYPE_F16 is defined,
// f32 otherwise.
#ifdef TYPE_F16
enable f16;
#define TYPE f16
#else
#define TYPE f32
#endif


  9@group(0) @binding(0)
 10var<storage, read_write> src: array<TYPE>;
 11
 12#ifndef INPLACE
 13@group(0) @binding(1)
 14var<storage, read_write> dst: array<TYPE>;
 15#define PARAMS_BINDING 2
 16#else
 17#define PARAMS_BINDING 1
 18#endif
 19
 20struct Params {
 21    ne: u32,            // total number of elements
 22    offset_src: u32,    // in elements
 23    offset_dst: u32,    // in elements
 24
 25    // Strides (in elements)
 26    stride_src0: u32,
 27    stride_src1: u32,
 28    stride_src2: u32,
 29    stride_src3: u32,
 30
 31    // Logical shapes
 32    ne0: u32,
 33    ne1: u32,
 34    ne2: u32,
 35#ifdef CLAMP
 36    clamp_min: f32,
 37    clamp_max: f32,
 38#endif
 39#ifdef FILL
 40    fill_val: f32,
 41#endif
 42#ifdef XIELU
 43    alpha_n: f32,
 44    alpha_p: f32,
 45    beta: f32,
 46    eps: f32,
 47#endif
 48
 49};
 50
 51@group(0) @binding(PARAMS_BINDING)
 52var<uniform> params: Params;
 53
 54@compute @workgroup_size(WG_SIZE)
 55fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 56    if (gid.x >= params.ne) {
 57      return;
 58    }
 59    var i = gid.x;
 60    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
 61    i = i % (params.ne2 * params.ne1 * params.ne0);
 62    let i2 = i / (params.ne1 * params.ne0);
 63    i = i % (params.ne1 * params.ne0);
 64    let i1 = i / params.ne0;
 65    let i0 = i % params.ne0;
 66
 67    let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 +
 68                  i2 * params.stride_src2 + i3 * params.stride_src3;
 69
 70#ifdef ABS
 71    let res = abs(src[params.offset_src + src_idx]);
 72#endif
 73#ifdef SGN
 74    let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0),
 75                     src[params.offset_src + src_idx] > 0.0);
 76#endif
 77#ifdef NEG
 78    let res = -src[params.offset_src + src_idx];
 79#endif
 80#ifdef STEP
 81    let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0));
 82#endif
 83#ifdef TANH
 84    let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913));
 85#endif
 86#ifdef RELU
 87    let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0);
 88#endif
 89#ifdef ELU
 90    let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx],
 91                     src[params.offset_src + src_idx] > 0.0);
 92#endif
 93#ifdef HARDSIGMOID
 94    let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
 95#endif
 96#ifdef SIGMOID
 97    let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx]));
 98#endif
 99#ifdef SILU
100    let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx]));
101#endif
102#ifdef EXP
103    let res = exp(src[params.offset_src + src_idx]);
104#endif
105#ifdef LOG
106    let res = TYPE(log(f32(src[params.offset_src + src_idx])));
107#endif
108#ifdef CLAMP
109    let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max));
110#endif
111#ifdef FILL
112    let res = TYPE(params.fill_val);
113#endif
114#ifdef HARDSWISH
115    let res = src[params.offset_src + src_idx] *
116              min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0));
117#endif
118#ifdef GELU
119    let res = 0.5 * src[params.offset_src + src_idx] *
120              (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) *
121                               (src[params.offset_src + src_idx] +
122                                0.044715 * pow(src[params.offset_src + src_idx], 3.0)),
123                               -9.010913, 9.010913)));
124#endif
125#ifdef GELU_QUICK
126    let res = src[params.offset_src + src_idx] * 0.5 *
127              (1.0 + tanh(clamp(0.79788456 *
128                               (src[params.offset_src + src_idx] +
129                                0.044715 * src[params.offset_src + src_idx] *
130                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
131                               -9.010913, 9.010913)));
132#endif
133#ifdef GELU_ERF
134    let res = 0.5 * src[params.offset_src + src_idx] *
135              (1.0 + tanh(clamp(0.79788456 *
136                               (src[params.offset_src + src_idx] +
137                                0.044715 * src[params.offset_src + src_idx] *
138                                    src[params.offset_src + src_idx] * src[params.offset_src + src_idx]),
139                               -9.010913, 9.010913)));
140#endif
141#ifdef XIELU
142    let res =
143        select(((exp(min(src[params.offset_src + src_idx], TYPE(params.eps))) - 1.0) -
144                src[params.offset_src + src_idx]) *
145                   TYPE(params.alpha_n) +
146               TYPE(params.beta) * src[params.offset_src + src_idx],
147               TYPE(params.alpha_p) * src[params.offset_src + src_idx] *
148                   src[params.offset_src + src_idx] +
149                   TYPE(params.beta) * src[params.offset_src + src_idx],
150               src[params.offset_src + src_idx] > 0.0);
151#endif
152#ifdef SOFTPLUS
153    let src_f32 = f32(src[params.offset_src + src_idx]);
154    let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0));
155#endif
156#ifdef EXPM1
157    let res = exp(src[params.offset_src + src_idx]) - 1.0;
158#endif
159#ifdef FLOOR
160    let res = floor(src[params.offset_src + src_idx]);
161#endif
162#ifdef CEIL
163    let res = ceil(src[params.offset_src + src_idx]);
164#endif
165#ifdef ROUND
166    let src_f32 = f32(src[params.offset_src + src_idx]);
167    let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0);
168    let res = TYPE(result);
169#endif
170#ifdef TRUNC
171    let res = trunc(src[params.offset_src + src_idx]);
172#endif
173
174#ifdef INPLACE
175    src[params.offset_src + src_idx] = res;
176#else
177    dst[params.offset_dst + gid.x] = res;
178#endif
179}