1#define(VARIANTS)
  2
  3[
  4  {
  5    "REPLS": {
  6      "TYPE" : "f32",
  7    },
  8    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
  9  },
 10  {
 11    "SHADER_SUFFIX": "f32_inplace",
 12    "REPLS": {
 13      "TYPE" : "f32",
 14    },
 15    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
 16  },
 17  {
 18    "REPLS": {
 19      "TYPE" : "f16",
 20    },
 21    "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
 22  },
 23  {
 24    "SHADER_SUFFIX": "f16_inplace",
 25    "REPLS": {
 26      "TYPE" : "f16",
 27    },
 28    "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
 29  },
 30  {
 31   "SHADER_SUFFIX": "f32_ff",
 32    "REPLS": {
 33      "TYPE" : "f32",
 34    },
 35    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
 36  },
 37  {
 38   "SHADER_SUFFIX": "f32_ff_inplace",
 39    "REPLS": {
 40      "TYPE" : "f32",
 41    },
 42    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
 43  },
 44  {
 45    "SHADER_SUFFIX": "f16_ff",
 46    "REPLS": {
 47      "TYPE" : "f16",
 48    },
 49    "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
 50  },
 51  {
 52    "SHADER_SUFFIX": "f16_ff_inplace",
 53    "REPLS": {
 54      "TYPE" : "f16",
 55    },
 56    "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
 57  }
 58]
 59
 60#end(VARIANTS)
 61
 62#define(DECLS)
 63
 64#decl(ROTATE)
 65fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
 66    dst[i_dst0] = {{TYPE}}(out0);
 67    dst[i_dst1] = {{TYPE}}(out1);
 68}
 69#enddecl(ROTATE)
 70
 71#decl(ROTATE_INPLACE)
 72fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
 73    src0[i_dst0] = {{TYPE}}(out0);
 74    src0[i_dst1] = {{TYPE}}(out1);
 75}
 76#enddecl(ROTATE_INPLACE)
 77
 78#decl(NO_FF_FUNC)
 79fn freq_factor(i: u32) -> f32 {
 80    return 1.0f;
 81}
 82#enddecl(NO_FF_FUNC)
 83
 84#decl(FF_FUNC)
 85fn freq_factor(i: u32) -> f32 {
 86    return src2[params.offset_src2 + i/2];
 87}
 88#enddecl(FF_FUNC)
 89
 90#decl(NO_FF_BINDINGS)
 91
// Out-of-place, no frequency factors: output buffer at binding 2.
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

// Kernel parameters follow the output buffer at binding 3.
@group(0) @binding(3)
var<uniform> params: Params;
 97
 98#enddecl(NO_FF_BINDINGS)
 99
100#decl(NO_FF_BINDINGS_INPLACE)
101
// In-place, no frequency factors: results are written back into src0, so no
// output buffer is bound and the parameters move up to binding 2.
@group(0) @binding(2)
var<uniform> params: Params;
104
105#enddecl(NO_FF_BINDINGS_INPLACE)
106
107#decl(FF_BINDINGS)
108
// Out-of-place with frequency factors: factors are read from src2 by
// freq_factor() (FF_FUNC variant).
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

// Output buffer.
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;

// Kernel parameters.
@group(0) @binding(4)
var<uniform> params: Params;
117
118#enddecl(FF_BINDINGS)
119
120#decl(FF_BINDINGS_INPLACE)
121
// In-place with frequency factors: factors come from src2; results are
// written back into src0, so there is no separate output binding.
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

// Kernel parameters.
@group(0) @binding(3)
var<uniform> params: Params;
127
128#enddecl(FF_BINDINGS_INPLACE)
129
130#end(DECLS)
131
132#define(SHADER)
133
134enable f16;
135
// Uniform parameters for the RoPE kernel. The field order defines the uniform
// layout. NOTE(review): presumably mirrored by the host-side parameter
// upload; keep the two in sync when editing.
struct Params {
    // Element offsets added to the base index into each buffer.
    offset_src0: u32,
    offset_src1: u32,
    offset_src2: u32,
    offset_dst: u32,

    // Strides (in elements)
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Number of active invocations; each one handles a pair of elements.
    n_threads: u32,
    // Extents of dims 0..2, used to decompose the flat element index in main.
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Leading elements of dim 0 that are rotated; the rest are copied
    // unchanged (except in vision mode, which rotates every pair).
    n_dims: u32,
    // RoPE mode: mask 2 = NeoX, mask 8 = M-RoPE; 24 = vision, 40 = interleaved M-RoPE.
    mode: u32,
    // Base raised to the per-pair exponent when computing theta.
    theta_scale: f32,
    // YaRN parameters consumed by rope_yarn / rope_yarn_ramp.
    attn_factor: f32,
    freq_scale: f32,
    ext_factor: f32,
    corr_dim0: f32,
    corr_dim1: f32,
    // M-RoPE section sizes, measured in pairs (main indexes sectors by i0 / 2).
    sections0: u32,
    sections1: u32,
    sections2: u32,
    sections3: u32
};
169
// Input values to rotate; the *_inplace variants also write results back here.
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

// Position values, read by main at offset_src1 + i2 (+ ne2 * stream index for
// M-RoPE) and converted to f32 before scaling.
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
175
176DECLS
177
// YaRN ramp for the pair containing dimension `i`: 1 at or below `low`, 0 at or
// above `high`, linear in between. The width (high - low) is floored at 0.001
// so the division never blows up.
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
    let width = max(0.001f, high - low);
    let ramp_pos = (f32(i / 2) - low) / width;
    let clamped = min(1.0f, max(0.0f, ramp_pos));
    return 1.0f - clamped;
}
182
// Returns the YaRN-adjusted rotation for one pair as
// vec2(cos(theta) * mscale, sin(theta) * mscale).
// TODO: check the performance of computing these once on the CPU and passing
// them in as a buffer, since the same values are recomputed for every row.
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
    let theta_interp = params.freq_scale * theta_extrap;
    var theta = theta_interp;
    var mscale = params.attn_factor;
    if (params.ext_factor != 0.0f) {
        // Blend the interpolated and extrapolated angles by the per-pair ramp,
        // and scale the magnitude by 1 + 0.1 * ln(1 / freq_scale).
        let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
        theta = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
        mscale = mscale * (1.0f + 0.1f * log(1.0f / params.freq_scale));
    }
    return vec2<f32>(cos(theta), sin(theta)) * mscale;
}
195
// Offset (within a row) of the first element of the pair for dimension i0:
// i0 / 2 when `div_2` is set, otherwise i0 itself.
fn pair_base(i0: u32, div_2: bool) -> u32 {
    return select(i0, i0 / 2u, div_2);
}
203
// Distance between the two elements of a rotated pair: n_dims for vision,
// n_dims / 2 for NeoX or M-RoPE, and 1 (adjacent elements) otherwise.
// Vision is checked first, so it wins even when the other flags are set.
fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
    if (is_vision) {
        return params.n_dims;
    }
    if (is_neox || is_mrope) {
        return params.n_dims / 2u;
    }
    return 1u;
}
213
override wg_size: u32;
// RoPE entry point. Each invocation either rotates one pair of elements or,
// for dimensions past n_dims (non-vision modes), copies the pair unchanged.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // Each invocation handles two elements starting at gid.x * 2.
    if (gid.x >= params.n_threads) {
        return;
    }

    let is_neox = bool(params.mode & 2);
    let is_mrope = bool(params.mode & 8);
    let is_imrope = params.mode == 40;
    let is_vision = params.mode == 24;

    // Decompose the flat start index into (i0, i1, i2, i3).
    let ne01 = params.ne1 * params.ne0;
    let ne012 = params.ne2 * ne01;
    var i = gid.x * 2;
    let i3 = i / ne012;
    i = i % ne012;
    let i2 = i / ne01;
    i = i % ne01;
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    // Elements past n_dims are passed through unrotated (vision rotates everything).
    if (i0 >= params.n_dims && !is_vision) {
        let i_src = i_src_row + i0;
        let i_dst = i_dst_row + i0;
        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
        return;
    }

    // theta_base_mult selects which position stream of src1 to read;
    // theta_scale_pwr is the exponent applied to theta_scale.
    var theta_base_mult: u32 = 0;
    var theta_scale_pwr: u32 = i0 / 2;
    if (is_mrope) {
        let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
        let sec_w = params.sections1 + params.sections0;
        let sec_e = params.sections2 + sec_w;
        let sector = (i0 / 2) % sect_dims;
        if (is_imrope) {
            // Interleaved M-RoPE: streams are assigned round-robin by sector % 3.
            if (sector % 3 == 1 && sector < 3 * params.sections1) {
                theta_base_mult = 1;
            } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
                theta_base_mult = 2;
            } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
                theta_base_mult = 0;
            } else {
                theta_base_mult = 3;
            }
        } else {
            // Contiguous sections. In vision mode the exponent is measured from
            // the start of the current section, consistently for all four.
            if (sector >= params.sections0 && sector < sec_w) {
                theta_base_mult = 1;
                if (is_vision) {
                    theta_scale_pwr = sector - params.sections0;
                }
            } else if (sector >= sec_w && sector < sec_e) {
                theta_base_mult = 2;
                if (is_vision) {
                    theta_scale_pwr = sector - sec_w;
                }
            } else if (sector >= sec_e) {
                theta_base_mult = 3;
                if (is_vision) {
                    // Previously overwritten by `(i0 / 2) % sec_e`, which diverges
                    // from the per-section rule once i0 / 2 wraps past sect_dims.
                    theta_scale_pwr = sector - sec_e;
                }
            } else if (is_vision) {
                theta_scale_pwr = sector;
            }
        }
    }
    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
    let thetas = rope_yarn(theta_base / freq_factor(i0), i0);

    // Locate the pair once; pair_base/pair_offset depend only on mode and i0.
    let split_pairs = is_neox || is_mrope || is_vision;
    let pair_dist = pair_offset(is_neox, is_mrope, is_vision);
    let i_src = i_src_row + pair_base(i0, split_pairs);
    let i_dst = i_dst_row + pair_base(i0, split_pairs);

    let x0 = f32(src0[i_src]);
    let x1 = f32(src0[i_src + pair_dist]);
    rotate(i_dst, i_dst + pair_dist, x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
294
295#end(SHADER)