1#define(VARIANTS)
  2
  3[
  4  {
  5    "SHADER_NAME": "reglu_f32",
  6    "REPLS": {
  7      "TYPE" : "f32",
  8    },
  9    "DECLS": ["NO_SPLIT", "REGLU"]
 10  },
 11  {
 12    "SHADER_NAME": "reglu_f32_split",
 13    "REPLS": {
 14      "TYPE" : "f32",
 15    },
 16    "DECLS": ["SPLIT", "REGLU"]
 17  },
 18  {
 19    "SHADER_NAME": "reglu_f16",
 20    "REPLS": {
 21      "TYPE" : "f16",
 22    },
 23    "DECLS": ["NO_SPLIT", "REGLU"]
 24  },
 25  {
 26    "SHADER_NAME": "reglu_f16_split",
 27    "REPLS": {
 28      "TYPE" : "f16",
 29    },
 30    "DECLS": ["SPLIT", "REGLU"]
 31  },
 32  {
 33    "SHADER_NAME": "geglu_f32",
 34    "REPLS": {
 35      "TYPE" : "f32",
 36    },
 37    "DECLS": ["NO_SPLIT", "GEGLU"]
 38  },
 39  {
 40    "SHADER_NAME": "geglu_f32_split",
 41    "REPLS": {
 42      "TYPE" : "f32",
 43    },
 44    "DECLS": ["SPLIT", "GEGLU"]
 45  },
 46  {
 47    "SHADER_NAME": "geglu_f16",
 48    "REPLS": {
 49      "TYPE" : "f16",
 50    },
 51    "DECLS": ["NO_SPLIT", "GEGLU"]
 52  },
 53  {
 54    "SHADER_NAME": "geglu_f16_split",
 55    "REPLS": {
 56      "TYPE" : "f16",
 57    },
 58    "DECLS": ["SPLIT", "GEGLU"]
 59  },
 60  {
 61    "SHADER_NAME": "swiglu_f32",
 62    "REPLS": {
 63      "TYPE" : "f32",
 64    },
 65    "DECLS": ["NO_SPLIT", "SWIGLU"]
 66  },
 67  {
 68    "SHADER_NAME": "swiglu_f32_split",
 69    "REPLS": {
 70      "TYPE" : "f32",
 71    },
 72    "DECLS": ["SPLIT", "SWIGLU"]
 73  },
 74  {
 75    "SHADER_NAME": "swiglu_f16",
 76    "REPLS": {
 77      "TYPE" : "f16",
 78    },
 79    "DECLS": ["NO_SPLIT", "SWIGLU"]
 80  },
 81  {
 82    "SHADER_NAME": "swiglu_f16_split",
 83    "REPLS": {
 84      "TYPE" : "f16",
 85    },
 86    "DECLS": ["SPLIT", "SWIGLU"]
 87  },
 88  {
 89    "SHADER_NAME": "swiglu_oai_f32",
 90    "REPLS": {
 91      "TYPE" : "f32",
 92    },
 93    "DECLS": ["NO_SPLIT", "SWIGLU_OAI"]
 94  },
 95  {
 96    "SHADER_NAME": "swiglu_oai_f32_split",
 97    "REPLS": {
 98      "TYPE" : "f32",
 99    },
100    "DECLS": ["SPLIT", "SWIGLU_OAI"]
101  },
102  {
103    "SHADER_NAME": "geglu_erf_f32",
104    "REPLS": {
105      "TYPE" : "f32",
106    },
107    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
108  },
109  {
110    "SHADER_NAME": "geglu_erf_f32_split",
111    "REPLS": {
112      "TYPE" : "f32",
113    },
114    "DECLS": ["SPLIT", "GEGLU_ERF"]
115  },
116  {
117    "SHADER_NAME": "geglu_erf_f16",
118    "REPLS": {
119      "TYPE" : "f16",
120    },
121    "DECLS": ["NO_SPLIT", "GEGLU_ERF"]
122  },
123  {
124    "SHADER_NAME": "geglu_erf_f16_split",
125    "REPLS": {
126      "TYPE" : "f16",
127    },
128    "DECLS": ["SPLIT", "GEGLU_ERF"]
129  },
130  {
131    "SHADER_NAME": "geglu_quick_f32",
132    "REPLS": {
133      "TYPE" : "f32",
134    },
135    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
136  },
137  {
138    "SHADER_NAME": "geglu_quick_f32_split",
139    "REPLS": {
140      "TYPE" : "f32",
141    },
142    "DECLS": ["SPLIT", "GEGLU_QUICK"]
143  },
144  {
145    "SHADER_NAME": "geglu_quick_f16",
146    "REPLS": {
147      "TYPE" : "f16",
148    },
149    "DECLS": ["NO_SPLIT", "GEGLU_QUICK"]
150  },
151  {
152    "SHADER_NAME": "geglu_quick_f16_split",
153    "REPLS": {
154      "TYPE" : "f16",
155    },
156    "DECLS": ["SPLIT", "GEGLU_QUICK"]
157  },
158]
159
160#end(VARIANTS)
161
162#define(DECLS)
163
#decl(REGLU)
// ReGLU gate: clamp the gate operand `a` from below at zero, then scale the
// result by the linear operand `b`.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let rectified = max(a, 0);
    return rectified * b;
}
#enddecl(REGLU)
169
#decl(GEGLU)
// Constants for the tanh-style GELU approximation used by `op` below.
const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876;
const GELU_COEF_A: {{TYPE}} = 0.044715;

// GEGLU gate: applies an approximate GELU to `a` and multiplies by `b`.
// The term (2 - 2 / (exp(2v) + 1)) is algebraically 1 + tanh(v), written with
// exp() instead of calling tanh() directly.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let poly = 1.0 + GELU_COEF_A * a * a;
    let val = SQRT_2_OVER_PI * a * poly;
    let exp_2v = exp(2 * val);
    let one_plus_tanh = 2.0 - 2.0 / (exp_2v + 1);
    return 0.5 * a * one_plus_tanh * b;
}
#enddecl(GEGLU)
179
#decl(SWIGLU)
// SwiGLU gate: computes a / (1 + exp(-a)) for the gate operand `a`, then
// multiplies by the linear operand `b`.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let denom = 1.0 + exp(-a);
    let gated = a / denom;
    return gated * b;
}
#enddecl(SWIGLU)
185
#decl(SWIGLU_OAI)
// Clamped SwiGLU variant, f32 only. Reads `params.limit` and `params.alpha`.
//   - `a` is capped from above at limit,
//   - `b` is clamped into [-limit, limit],
//   - result is a' / (1 + exp(-a' * alpha)) * (1 + b').
fn op(a: f32, b: f32) -> f32 {
    let gate = min(a, params.limit);
    let linear = max(min(b, params.limit), -params.limit);
    let swish = gate / (1.0 + exp(-gate * params.alpha));
    return swish * (1.0 + linear);
}
#enddecl(SWIGLU_OAI)
195
#decl(GEGLU_ERF)
// Coefficients of the polynomial erf approximation evaluated in `op` below
// (they match the Abramowitz-Stegun 7.1.26 formula).
const p_erf: {{TYPE}} = 0.3275911;
const a1_erf: {{TYPE}} = 0.254829592;
const a2_erf: {{TYPE}} = -0.284496736;
const a3_erf: {{TYPE}} = 1.421413741;
const a4_erf: {{TYPE}} = -1.453152027;
const a5_erf: {{TYPE}} = 1.061405429;
const SQRT_2_INV: {{TYPE}} = 0.7071067811865476;

// GEGLU gate using an erf-based GELU: 0.5 * a * (1 + erf(a / sqrt(2))) * b,
// with erf replaced by the polynomial approximation on |x| and the sign
// restored afterwards.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let scaled = a * SQRT_2_INV;
    let sgn = sign(scaled);
    let mag = abs(scaled);
    let t = 1.0 / (1.0 + p_erf * mag);

    // Horner evaluation, highest-order coefficient first.
    var poly = a5_erf * t + a4_erf;
    poly = poly * t + a3_erf;
    poly = poly * t + a2_erf;
    poly = poly * t + a1_erf;

    let erf_mag = 1.0 - (poly * t * exp(-mag * mag));
    let erf_val = sgn * erf_mag;
    return 0.5 * a * (1.0 + erf_val) * b;
}
#enddecl(GEGLU_ERF)
215
#decl(GEGLU_QUICK)
// Scale applied to the gate inside the exponential of the sigmoid below.
const GELU_QUICK_COEF: {{TYPE}} = -1.702;

// Quick-GELU gate: a * (1 / (1 + exp(GELU_QUICK_COEF * a))) * b.
fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} {
    let sigmoid = 1.0 / (1.0 + exp(GELU_QUICK_COEF * a));
    return a * sigmoid * b;
}
#enddecl(GEGLU_QUICK)
223
#decl(NO_SPLIT)
// Single-input layout: both operands are read from `src0` (bound at
// binding 0 elsewhere in this file), so only the output buffer and the
// parameter block are bound here.
@group(0) @binding(1) var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(2) var<uniform> params: Params;

// Gate operand: read at `base`, shifted by ne0 elements when
// params.swapped is nonzero.
fn a_value(base: u32) -> {{TYPE}} {
    var shift: u32 = 0;
    if (params.swapped != 0) {
        shift = params.ne0;
    }
    return src0[base + shift];
}

// Linear operand: the opposite choice from a_value -- shifted by ne0
// elements unless params.swapped is nonzero.
// NOTE(review): main builds the `base` passed here from offset_src1 and
// the stride_src1* fields, yet the read goes to src0; presumably the host
// fills those fields with src0's values for this variant -- confirm.
fn b_value(base: u32) -> {{TYPE}} {
    var shift: u32 = params.ne0;
    if (params.swapped != 0) {
        shift = 0;
    }
    return src0[base + shift];
}
#enddecl(NO_SPLIT)
241
#decl(SPLIT)
// Two-input layout: the second operand has its own buffer `src1`, which
// moves the output and parameter bindings up by one slot compared with the
// single-input layout.
@group(0) @binding(1) var<storage, read_write> src1: array<{{TYPE}}>;
@group(0) @binding(2) var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3) var<uniform> params: Params;

// Gate operand: element `base` of src0.
fn a_value(base: u32) -> {{TYPE}} {
    let gate = src0[base];
    return gate;
}

// Linear operand: element `base` of src1.
fn b_value(base: u32) -> {{TYPE}} {
    let linear = src1[base];
    return linear;
}
#enddecl(SPLIT)
260
261#end(DECLS)
262
263#define(SHADER)
264
265enable f16;
266
// Parameter block bound as a uniform (`params`) by the selected binding
// layout and read by `main` and by the per-variant helper functions.
struct Params {
    // Base element offsets that main adds to the src0 / src1 / dst indices.
    offset_src0: u32,
    offset_src1: u32,
    offset_dst: u32,

    // Strides (in elements)
    // Multiplied by the i1 / i2 / i3 coordinates when indexing src0.
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    // Same, for the index main passes to b_value.
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    // Same, for the destination index.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of dst
    // ne bounds the invocation id: main returns early when gid.x >= ne.
    // ne0..ne2 are the extents main uses to split gid.x into coordinates.
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Nonzero swaps which operand gets the ne0 shift in the single-input
    // layout (see a_value / b_value there).
    swapped: u32,
    // Read only by the SWIGLU_OAI op in this file: alpha scales the exponent
    // argument, limit bounds both operands.
    alpha: f32,
    limit: f32,
}
295
// First input buffer, shared by every variant; a_value always reads from it.
@group(0) @binding(0) var<storage, read_write> src0: array<{{TYPE}}>;
298
299DECLS
300
// Workgroup width, supplied as a pipeline-overridable constant.
override wg_size: u32;

// Entry point: one invocation per output element. The flat id gid.x is split
// into coordinates (i0, i1, i2, i3) using the extents ne0, ne1, ne2, and the
// element indices for both operands and for the output are built from the
// per-buffer offsets and strides in `params`.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let flat = gid.x;
    // Invocations beyond the element count have nothing to do.
    if (flat >= params.ne) {
        return;
    }

    // Element counts of one i2-slice and of one i3-slice.
    let slice_elems = params.ne1 * params.ne0;
    let block_elems = params.ne2 * slice_elems;

    // Peel off the coordinates from outermost (i3) to innermost (i0).
    let i3 = flat / block_elems;
    let in_block = flat % block_elems;
    let i2 = in_block / slice_elems;
    let in_slice = in_block % slice_elems;
    let i1 = in_slice / params.ne0;
    let i0 = in_slice % params.ne0;

    // The innermost dimension is indexed with an implicit stride of one.
    let i_a = params.offset_src0
        + i3 * params.stride_src03
        + i2 * params.stride_src02
        + i1 * params.stride_src01
        + i0;
    let i_b = params.offset_src1
        + i3 * params.stride_src13
        + i2 * params.stride_src12
        + i1 * params.stride_src11
        + i0;
    let i_dst = params.offset_dst
        + i3 * params.stride_dst3
        + i2 * params.stride_dst2
        + i1 * params.stride_dst1
        + i0;

    let gate = a_value(i_a);
    let linear = b_value(i_b);
    dst[i_dst] = op(gate, linear);
}
322
323#end(SHADER)