1#define(VARIANTS)
  2[
  3  {
  4    "SHADER_NAME": "soft_max_f32",
  5    "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"]
  6  },
  7  {
  8    "SHADER_NAME": "soft_max_f32_inplace",
  9    "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"]
 10  },
 11  {
 12    "SHADER_NAME": "soft_max_f32_sink",
 13    "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"]
 14  },
 15  {
 16    "SHADER_NAME": "soft_max_f32_sink_inplace",
 17    "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"]
 18  },
 19  {
 20    "SHADER_NAME": "soft_max_f32_mask_f32",
 21    "REPLS": {
 22      "MASK_TYPE" : "f32",
 23    },
 24    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
 25  },
 26  {
 27    "SHADER_NAME": "soft_max_f32_mask_f32_inplace",
 28    "REPLS": {
 29      "MASK_TYPE" : "f32",
 30    },
 31    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
 32  },
 33  {
 34    "SHADER_NAME": "soft_max_f32_mask_f16",
 35    "REPLS": {
 36      "MASK_TYPE" : "f16",
 37    },
 38    "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"]
 39  },
 40  {
 41    "SHADER_NAME": "soft_max_f32_mask_f16_inplace",
 42    "REPLS": {
 43      "MASK_TYPE" : "f16",
 44    },
 45    "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"]
 46  },
 47  {
 48    "SHADER_NAME": "soft_max_f32_mask_f32_sink",
 49    "REPLS": {
 50      "MASK_TYPE" : "f32",
 51    },
 52    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
 53  },
 54  {
 55    "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace",
 56    "REPLS": {
 57      "MASK_TYPE" : "f32",
 58    },
 59    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
 60  },
 61  {
 62    "SHADER_NAME": "soft_max_f32_mask_f16_sink",
 63    "REPLS": {
 64      "MASK_TYPE" : "f16",
 65    },
 66    "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"]
 67  },
 68  {
 69    "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace",
 70    "REPLS": {
 71      "MASK_TYPE" : "f16",
 72    },
 73    "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"]
 74  }
 75]
 76#end(VARIANTS)
 77
 78#define(DECLS)
 79
#decl(BASE_BINDINGS)
// Bindings for the out-of-place variant that uses neither a mask nor sinks.
// Binding 0 is the src buffer declared in the shader body, so numbering here starts at 1.
// Output buffer; accessed through inter_value() and update() of the NOT_INPLACE snippet.
@group(0) @binding(1)
var<storage, read_write> dst: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS)
 87
#decl(BASE_BINDINGS_INPLACE)
// Bindings for the in-place variant without mask or sinks.
// No dst buffer is declared: the INPLACE snippet reads and writes src (binding 0) instead.
// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(1)
var<uniform> params: Params;
#enddecl(BASE_BINDINGS_INPLACE)
 92
#decl(SINK_BINDINGS)
// Bindings for the out-of-place variant with sinks and no mask.
// Sink values; indexed as params.offset_sinks + i2 by the SINK snippet.
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

// Output buffer; accessed through inter_value() and update() of the NOT_INPLACE snippet.
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS)
103
#decl(SINK_BINDINGS_INPLACE)
// Bindings for the in-place variant with sinks and no mask.
// No dst buffer is declared: the INPLACE snippet reads and writes src (binding 0) instead.
// Sink values; indexed as params.offset_sinks + i2 by the SINK snippet.
@group(0) @binding(1)
var<storage, read_write> sinks: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(SINK_BINDINGS_INPLACE)
111
#decl(MASK_BINDINGS)
// Bindings for the out-of-place variant with a mask and no sinks.
// Mask buffer; its element type is filled in from the variant's MASK_TYPE replacement
// and each element is converted to f32 by mask_val() in the MASK snippet.
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

// Output buffer; accessed through inter_value() and update() of the NOT_INPLACE snippet.
@group(0) @binding(2)
var<storage, read_write> dst: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS)
122
#decl(MASK_BINDINGS_INPLACE)
// Bindings for the in-place variant with a mask and no sinks.
// No dst buffer is declared: the INPLACE snippet reads and writes src (binding 0) instead.
// Mask buffer; its element type is filled in from the variant's MASK_TYPE replacement.
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(MASK_BINDINGS_INPLACE)
130
#decl(MASK_SINK_BINDINGS)
// Bindings for the out-of-place variant with both a mask and sinks.
// Mask buffer; its element type is filled in from the variant's MASK_TYPE replacement.
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

// Sink values; indexed as params.offset_sinks + i2 by the SINK snippet.
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

// Output buffer; accessed through inter_value() and update() of the NOT_INPLACE snippet.
@group(0) @binding(3)
var<storage, read_write> dst: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(4)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS)
144
#decl(MASK_SINK_BINDINGS_INPLACE)
// Bindings for the in-place variant with both a mask and sinks.
// No dst buffer is declared: the INPLACE snippet reads and writes src (binding 0) instead.
// Mask buffer; its element type is filled in from the variant's MASK_TYPE replacement.
@group(0) @binding(1)
var<storage, read_write> mask: array<{{MASK_TYPE}}>;

// Sink values; indexed as params.offset_sinks + i2 by the SINK snippet.
@group(0) @binding(2)
var<storage, read_write> sinks: array<f32>;

// Uniform block with offsets, strides, shape and scaling parameters.
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(MASK_SINK_BINDINGS_INPLACE)
155
#decl(NOT_INPLACE)
// Out-of-place storage helpers: intermediate and final values live in dst.

// Returns the value currently stored at dst[idx].
fn inter_value(idx: u32) -> f32 {
    let stored = dst[idx];
    return stored;
}

// Stores value at dst[idx].
fn update(idx: u32, value: f32) {
    dst[idx] = value;
}
#enddecl(NOT_INPLACE)
165
#decl(INPLACE)
// In-place storage helpers: intermediate and final values overwrite src.

// Returns the value currently stored at src[idx].
fn inter_value(idx: u32) -> f32 {
    let stored = src[idx];
    return stored;
}

// Stores value at src[idx].
fn update(idx: u32, value: f32) {
    src[idx] = value;
}
#enddecl(INPLACE)
175
#decl(NO_MASK)
// Maskless variant: the mask term is zero for every index.
fn mask_val(idx: u32) -> f32 {
    let no_mask: f32 = 0.0;
    return no_mask;
}
#enddecl(NO_MASK)
181
#decl(MASK)
// Masked variant: loads the mask element at idx and converts it to f32.
fn mask_val(idx: u32) -> f32 {
    let raw = mask[idx];
    return f32(raw);
}
#enddecl(MASK)
187
#decl(NO_SINK)
// Sink-free variant.

// Initial value for the running row maximum: a large negative constant,
// independent of the head index.
fn lower_max_bound(head_idx: u32) -> f32 {
    let seed: f32 = -1e30;
    return seed;
}

// Without sinks the row sum is passed through unchanged.
fn add_sinks(row_total: f32, head_idx: u32, row_peak: f32) -> f32 {
    return row_total;
}
#enddecl(NO_SINK)
197
#decl(SINK)
// Sink variant: one sink value per i2 index, read from sinks at params.offset_sinks + i2.

// Initial value for the running row maximum: the sink value for this index,
// so the final maximum is never below it.
fn lower_max_bound(head_idx: u32) -> f32 {
    let sink_pos = params.offset_sinks + head_idx;
    return sinks[sink_pos];
}

// Adds the sink's exponential term, shifted by the row maximum, to the row sum.
fn add_sinks(row_total: f32, head_idx: u32, row_peak: f32) -> f32 {
    let sink_value = sinks[params.offset_sinks + head_idx];
    let sink_term = exp(sink_value - row_peak);
    return row_total + sink_term;
}
#enddecl(SINK)
207
208#end(DECLS)
209
210#define(SHADER)
211enable f16;
212
// Uniform parameters for the softmax kernel. Field order defines the uniform layout,
// so it has to stay in sync with whatever the host writes into the buffer.
struct Params {
    // Base offsets added to the element indices used for src, mask, sinks and dst.
    offset_src0: u32,
    offset_src1: u32,
    offset_sinks: u32,
    offset_dst: u32,

    // Strides (in elements)
    // src strides, multiplied by the row coordinates i1, i2, i3 respectively.
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    // Mask strides, multiplied by i1, (i2 % ne12), (i3 % ne13) respectively.
    stride_src11: u32,
    stride_src12: u32,
    stride_src13: u32,

    // dst strides, multiplied by the row coordinates i1, i2, i3 respectively.
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // shape of src0/dst
    // ne is not read by the kernel code in this file; ne0 is the row length,
    // ne1 and ne2 are used to split the workgroup index into row coordinates.
    ne: u32,
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // shape of src1
    // Used as moduli so the mask repeats along dimensions 2 and 3.
    ne12: u32,
    ne13: u32,

    // scale multiplies every src value before the mask term is added.
    scale: f32,
    // The per-head mask slope is only computed when max_bias > 0; otherwise it is 1.
    max_bias: f32,
    // Heads below n_head_log2 use m0 as the slope base, the others use m1.
    n_head_log2: f32,
    m0: f32,
    m1: f32,
};
248
// Input values. The in-place variants also store intermediate and final results here,
// which is why the buffer is bound read_write.
@group(0) @binding(0)
var<storage, read_write> src: array<f32>;
251
252DECLS
253
// Length of the per-invocation cache array in main; only columns with an index
// below this value are kept in it between passes.
const CACHE_SIZE: u32 = 16;

// Workgroup size, supplied as a pipeline-overridable constant (no default is given here).
// NOTE(review): the reductions in main start at wg_size / 2 and keep halving, which only
// folds in every slot when wg_size is a power of two; confirm the host guarantees that.
override wg_size: u32;
// Workgroup-shared scratch space for the max and sum reductions, one slot per invocation.
var<workgroup> scratch: array<f32, wg_size>;
258
// Softmax over one row of src per workgroup.
//
// The flat workgroup index wid.x selects the row (i3, i2, i1); the invocations of the
// workgroup walk the ne0 columns of that row with a stride of wg_size. Three passes:
//   1. compute the scaled and masked value of every column and the row maximum,
//   2. exponentiate relative to the maximum and accumulate the row sum,
//   3. multiply every exponential by the reciprocal of the row sum.
// Columns with an index below CACHE_SIZE keep their intermediate value in a private
// array; all other columns round-trip through the output buffer via update() and
// inter_value(), which are provided by the variant-specific helper snippets.
@compute @workgroup_size(wg_size)
fn main(@builtin(workgroup_id) wid: vec3<u32>,
        @builtin(local_invocation_id) lid: vec3<u32>) {

    // Split the flat workgroup index into row coordinates.
    var i = wid.x;
    let i3 = i / (params.ne2 * params.ne1);
    i = i % (params.ne2 * params.ne1);
    let i2 = i / params.ne1;
    let i1 = i % params.ne1;
    let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    // The mask repeats along dimensions 2 and 3, hence the modulo on i2 and i3.
    let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
    // Upper bound on the number of columns a single invocation has to visit.
    let elems = (params.ne0 + wg_size - 1) / wg_size;

    // Per-head slope multiplying the mask term; stays 1 unless max_bias is positive.
    let head = f32(i2);
    let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0);

    var cache: array<f32, CACHE_SIZE>;

    // Pass 1: scaled and masked values plus this invocation's partial maximum.
    var max_val = lower_max_bound(i2);
    var col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        let val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
        max_val = max(max_val, val);
        if (col < CACHE_SIZE) {
            cache[col] = val;
        }
        col += wg_size;
    }

    // Reduce the partial maxima across the workgroup by repeatedly halving the offset.
    scratch[lid.x] = max_val;
    workgroupBarrier();
    var offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]);
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    let row_max = scratch[0];
    // Everyone must have read scratch[0] before scratch is reused for the sum.
    workgroupBarrier();

    // Pass 2: exponentials and this invocation's partial sum.
    var sum = 0.0f;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        // Explicit branches instead of select(): select() evaluates both operands,
        // which would reload src and the mask even for cached columns and would
        // index the cache past its end for the uncached ones.
        let cached = col < CACHE_SIZE;
        var val: f32;
        if (cached) {
            val = cache[col];
        } else {
            val = src[i_src0_row + col] * params.scale + slope * mask_val(i_src1_row + col);
        }
        let ex = exp(val - row_max);
        sum += ex;
        if (cached) {
            cache[col] = ex;
        } else {
            update(i_dst_row + col, ex);
        }
        col += wg_size;
    }

    // Reduce the partial sums across the workgroup.
    scratch[lid.x] = sum;
    workgroupBarrier();
    offset = wg_size / 2;
    while (offset > 0) {
        if (lid.x < offset) {
            scratch[lid.x] += scratch[lid.x + offset];
        }
        offset = offset / 2;
        workgroupBarrier();
    }
    // Variants with sinks add the sink's exponential term to the denominator here.
    let row_sum = add_sinks(scratch[0], i2, row_max);

    // Pass 3: normalise and write the final values.
    let sum_recip = 1.0 / row_sum;
    col = lid.x;
    for (var j: u32 = 0; j < elems; j++) {
        if (col >= params.ne0) {
            break;
        }
        let out_idx = i_dst_row + col;
        // Same reasoning as in pass 2: only touch the source that actually holds
        // the exponential for this column.
        if (col < CACHE_SIZE) {
            update(out_idx, cache[col] * sum_recip);
        } else {
            update(out_idx, inter_value(out_idx) * sum_recip);
        }
        col += wg_size;
    }
}
345#end(SHADER)