1#define(VARIANTS)
2
3[
4 {
5 "REPLS": {
6 "TYPE" : "f32",
7 },
8 "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
9 },
10 {
11 "SHADER_SUFFIX": "f32_inplace",
12 "REPLS": {
13 "TYPE" : "f32",
14 },
15 "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
16 },
17 {
18 "REPLS": {
19 "TYPE" : "f16",
20 },
21 "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
22 },
23 {
24 "SHADER_SUFFIX": "f16_inplace",
25 "REPLS": {
26 "TYPE" : "f16",
27 },
28 "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
29 },
30 {
31 "SHADER_SUFFIX": "f32_ff",
32 "REPLS": {
33 "TYPE" : "f32",
34 },
35 "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
36 },
37 {
38 "SHADER_SUFFIX": "f32_ff_inplace",
39 "REPLS": {
40 "TYPE" : "f32",
41 },
42 "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
43 },
44 {
45 "SHADER_SUFFIX": "f16_ff",
46 "REPLS": {
47 "TYPE" : "f16",
48 },
49 "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
50 },
51 {
52 "SHADER_SUFFIX": "f16_ff_inplace",
53 "REPLS": {
54 "TYPE" : "f16",
55 },
56 "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
57 }
58]
59
60#end(VARIANTS)
61
62#define(DECLS)
63
64#decl(ROTATE)
// Stores one rotated pair into dst at indices i_dst0 / i_dst1, converting the
// f32 results to the variant's element type first.
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    let first = {{TYPE}}(out0);
    let second = {{TYPE}}(out1);
    dst[i_dst0] = first;
    dst[i_dst1] = second;
}
69#enddecl(ROTATE)
70
71#decl(ROTATE_INPLACE)
// In-place variant: stores one rotated pair back into src0 at indices
// i_dst0 / i_dst1, converting the f32 results to the variant's element type.
fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
    let first = {{TYPE}}(out0);
    let second = {{TYPE}}(out1);
    src0[i_dst0] = first;
    src0[i_dst1] = second;
}
76#enddecl(ROTATE_INPLACE)
77
78#decl(NO_FF_FUNC)
// Variant without a frequency-factor buffer: every element index i uses a
// factor of one, so theta is left unscaled.
fn freq_factor(i: u32) -> f32 {
    return f32(1.0);
}
82#enddecl(NO_FF_FUNC)
83
84#decl(FF_FUNC)
// Reads the frequency factor for element index i from src2; one factor is
// stored per pair of elements, starting at params.offset_src2.
fn freq_factor(i: u32) -> f32 {
    let pair_index = i / 2u;
    let slot = params.offset_src2 + pair_index;
    return src2[slot];
}
88#enddecl(FF_FUNC)
89
90#decl(NO_FF_BINDINGS)
91
// Output tensor that rotate() writes to; element type matches src0.
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;

// Kernel parameters (struct Params).
// NOTE(review): binding numbers presumably must match the host-side bind group layout.
@group(0) @binding(3)
var<uniform> params: Params;
97
98#enddecl(NO_FF_BINDINGS)
99
100#decl(NO_FF_BINDINGS_INPLACE)
101
// In-place layout without frequency factors: no dst buffer is declared here,
// so the uniform parameters occupy binding 2.
// NOTE(review): binding number presumably must match the host-side bind group layout.
@group(0) @binding(2)
var<uniform> params: Params;
104
105#enddecl(NO_FF_BINDINGS_INPLACE)
106
107#decl(FF_BINDINGS)
108
// Frequency factors read by freq_factor(): one f32 per pair of elements,
// indexed from params.offset_src2.
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

// Output tensor that rotate() writes to; element type matches src0.
@group(0) @binding(3)
var<storage, read_write> dst: array<{{TYPE}}>;

// Kernel parameters (struct Params).
// NOTE(review): binding numbers presumably must match the host-side bind group layout.
@group(0) @binding(4)
var<uniform> params: Params;
117
118#enddecl(FF_BINDINGS)
119
120#decl(FF_BINDINGS_INPLACE)
121
// Frequency factors read by freq_factor(): one f32 per pair of elements,
// indexed from params.offset_src2.
@group(0) @binding(2)
var<storage, read_write> src2: array<f32>;

// In-place layout: no dst buffer is declared here, so the uniform parameters
// follow src2 at binding 3.
// NOTE(review): binding numbers presumably must match the host-side bind group layout.
@group(0) @binding(3)
var<uniform> params: Params;
127
128#enddecl(FF_BINDINGS_INPLACE)
129
130#end(DECLS)
131
132#define(SHADER)
133
134enable f16;
135
// Uniform parameters for the RoPE kernel.
// NOTE(review): field order and types presumably mirror the host-side struct
// that fills this uniform buffer — keep both in sync when editing.
struct Params {
    // Base element offsets into src0 (input), src1 (positions),
    // src2 (frequency factors) and dst (output).
    offset_src0: u32,
    offset_src1: u32,
    offset_src2: u32,
    offset_dst: u32,

    // Strides (in elements)
    // Row / plane / batch strides of src0 (dims 1, 2, 3).
    stride_src01: u32,
    stride_src02: u32,
    stride_src03: u32,

    // Row / plane / batch strides of dst (dims 1, 2, 3).
    stride_dst1: u32,
    stride_dst2: u32,
    stride_dst3: u32,

    // Number of active invocations; main() returns early for gid.x >= n_threads.
    n_threads: u32,
    // Extents of dims 0..2, used to decompose the flat element index.
    ne0: u32,
    ne1: u32,
    ne2: u32,

    // Number of leading elements per row that are rotated (except in vision
    // mode, which rotates the whole row); the remainder is copied through.
    n_dims: u32,
    // Bit flags / mode id: bit 2 -> NeoX pairing, bit 8 -> M-RoPE family,
    // 24 -> vision, 40 -> interleaved M-RoPE (as tested in main()).
    mode: u32,
    // Base raised to the pair index to scale the position angle.
    theta_scale: f32,
    // Initial magnitude scale applied to cos/sin.
    attn_factor: f32,
    // Interpolation factor applied to the extrapolated angle.
    freq_scale: f32,
    // YaRN blend strength; 0 disables the ramp and magnitude correction.
    ext_factor: f32,
    // Low / high bounds of the YaRN ramp, in pair-index units.
    corr_dim0: f32,
    corr_dim1: f32,
    // M-RoPE section sizes (in pairs) for position components 0..3.
    sections0: u32,
    sections1: u32,
    sections2: u32,
    sections3: u32
};
169
// Input tensor; the in-place variants' rotate() also writes results back here.
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;

// Positions: main() reads src1[offset_src1 + i2 + ne2 * k], where k selects the
// position component (always 0 outside the M-RoPE modes).
@group(0) @binding(1)
var<storage, read_write> src1: array<i32>;
175
176DECLS
177
// YaRN ramp over the pair index i / 2: returns 1 at or below `low`, 0 at or
// above `high`, and falls linearly in between. The denominator is floored at
// 0.001 so a degenerate (high <= low) range does not divide by zero.
fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
    let span = max(0.001f, high - low);
    let progress = (f32(i / 2) - low) / span;
    return 1.0f - saturate(progress);
}
182
// Returns vec2<f32>(cos(theta) * mscale, sin(theta) * mscale) for element index i.
// theta starts as the interpolated angle freq_scale * theta_extrap. When
// ext_factor is non-zero, it is blended toward theta_extrap using
// rope_yarn_ramp(), and the magnitude is boosted by 1 + 0.1 * ln(1 / freq_scale).
// TODO: check the performance of computing these once on the CPU and passing
// them in as a buffer, since they are recomputed for every row.
fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
    let theta_interp = params.freq_scale * theta_extrap;
    if (params.ext_factor == 0.0f) {
        let plain_scale = params.attn_factor;
        return vec2<f32>(cos(theta_interp) * plain_scale, sin(theta_interp) * plain_scale);
    }
    let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
    let blended = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;
    let boosted_scale = params.attn_factor * (1.0f + 0.1f * log(1.0f / params.freq_scale));
    return vec2<f32>(cos(blended) * boosted_scale, sin(blended) * boosted_scale);
}
195
// Offset within a row of the first element of the pair handled for i0:
// i0 / 2 when div_2 is set, otherwise i0 itself.
fn pair_base(i0: u32, div_2: bool) -> u32 {
    return select(i0, i0 / 2, div_2);
}
203
// Distance between the two elements of a rotated pair:
// n_dims in vision mode, n_dims / 2 for NeoX or M-RoPE, and 1 (adjacent) otherwise.
fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
    if (is_vision) {
        return params.n_dims;
    }
    if (is_neox || is_mrope) {
        return params.n_dims / 2;
    }
    return 1u;
}
213
// For M-RoPE family modes, chooses which position component and which
// theta_scale exponent apply to element i0.
// Returns vec2<u32>(theta_base_mult, theta_scale_pwr):
//   - theta_base_mult selects the position read as src1[offset_src1 + i2 + ne2 * mult];
//   - theta_scale_pwr is the exponent applied to theta_scale.
// Outside vision mode the exponent is always i0 / 2. In vision mode, each
// section restarts its exponent at 0 from the section's first sector.
fn mrope_theta_select(i0: u32, is_imrope: bool, is_vision: bool) -> vec2<u32> {
    let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
    let sec_w = params.sections1 + params.sections0;
    let sec_e = params.sections2 + sec_w;
    let sector = (i0 / 2) % sect_dims;

    var theta_base_mult: u32 = 0;
    var theta_scale_pwr: u32 = i0 / 2;
    if (is_imrope) {
        // Components are chosen by sector % 3, each limited to sectors below
        // 3 * its section size; everything else uses component 3.
        if (sector % 3 == 1 && sector < 3 * params.sections1) {
            theta_base_mult = 1;
        } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
            theta_base_mult = 2;
        } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
            theta_base_mult = 0;
        } else {
            theta_base_mult = 3;
        }
    } else {
        // Contiguous sections: [0, s0) -> 0, [s0, sec_w) -> 1,
        // [sec_w, sec_e) -> 2, [sec_e, sect_dims) -> 3.
        if (sector >= params.sections0 && sector < sec_w) {
            theta_base_mult = 1;
            if (is_vision) {
                theta_scale_pwr = sector - params.sections0;
            }
        } else if (sector >= sec_w && sector < sec_e) {
            theta_base_mult = 2;
            if (is_vision) {
                theta_scale_pwr = sector - sec_w;
            }
        } else if (sector >= sec_e) {
            theta_base_mult = 3;
            if (is_vision) {
                // Like the other sections, the exponent restarts at sec_e.
                theta_scale_pwr = sector - sec_e;
            }
        } else if (is_vision) {
            theta_scale_pwr = sector;
        }
    }
    return vec2<u32>(theta_base_mult, theta_scale_pwr);
}

override wg_size: u32;

// Entry point: each invocation handles two elements (one pair).
// Elements at or beyond n_dims in a row are copied through unchanged (except
// in vision mode). All other pairs are rotated by the position-dependent angle
// from rope_yarn() and written via rotate().
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // two elements per thread
    if (gid.x >= params.n_threads) {
        return;
    }

    let is_neox = bool(params.mode & 2);
    let is_mrope = bool(params.mode & 8);
    let is_imrope = params.mode == 40;
    let is_vision = params.mode == 24;

    // Decompose the flat index of this thread's first element into (i0, i1, i2, i3).
    var i = gid.x * 2; // start index for this thread
    let i3 = i / (params.ne2 * params.ne1 * params.ne0);
    i = i % (params.ne2 * params.ne1 * params.ne0);
    let i2 = i / (params.ne1 * params.ne0);
    i = i % (params.ne1 * params.ne0);
    let i1 = i / params.ne0;
    let i0 = i % params.ne0;

    let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;

    // Unrotated tail of the row: copy both elements as-is.
    if (i0 >= params.n_dims && !is_vision) {
        let i_src = i_src_row + i0;
        let i_dst = i_dst_row + i0;
        rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
        return;
    }

    var theta_base_mult: u32 = 0;
    var theta_scale_pwr: u32 = i0 / 2;
    if (is_mrope) {
        let selection = mrope_theta_select(i0, is_imrope, is_vision);
        theta_base_mult = selection.x;
        theta_scale_pwr = selection.y;
    }
    let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
    let thetas = rope_yarn(theta_base / freq_factor(i0), i0);

    // Locate the pair once; both source and destination use the same layout.
    let pair_start = pair_base(i0, is_neox || is_mrope || is_vision);
    let pair_gap = pair_offset(is_neox, is_mrope, is_vision);
    let i_src = i_src_row + pair_start;
    let i_dst = i_dst_row + pair_start;

    let x0 = f32(src0[i_src]);
    let x1 = f32(src0[i_src + pair_gap]);
    rotate(i_dst, i_dst + pair_gap, x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
}
294
295#end(SHADER)