1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3//------------------------------------------------------------------------------
4// kernel_rope
5//------------------------------------------------------------------------------
// YaRN ramp weight for the rotary pair that starts at element i0.
// The pair index (i0 / 2, integer division) is mapped linearly onto [0, 1]
// across the correction range [low, high], clamped, and inverted. The weight
// is therefore 1 at or below `low` and 0 at or above `high`. The denominator
// is floored at 0.001 so an empty or inverted range never divides by zero.
float rope_yarn_ramp(float low, float high, int i0) {
    const float pair_idx = (float)(i0 / 2);
    const float span     = max(0.001f, high - low);
    return 1.0f - clamp((pair_idx - low) / span, 0.0f, 1.0f);
}
10
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes (cos(theta) * m, sin(theta) * m) for one rotary pair.
//   theta_extrap - angle before frequency interpolation (the kernels in this
//                  file pass position * base^(-i0/n_dims) / freq_factor)
//   freq_scale   - interpolation factor; theta_interp = freq_scale * theta_extrap
//   corr_dims    - (low, high) correction range, forwarded to rope_yarn_ramp
//   i0           - index of the pair's first element
//   ext_factor   - YaRN mixing strength; exactly 0 disables YaRN
//   mscale       - base magnitude (the kernels in this file pass attn_factor)
// When ext_factor != 0, the angle is a ramp-weighted blend of the interpolated
// and extrapolated angles. The magnitude is also multiplied by
// 1 + 0.1 * ln(1 / freq_scale).
float2 rope_yarn(
    float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
) {
    const float theta_interp = freq_scale * theta_extrap;

    if (ext_factor == 0.0f) {
        // Plain linear interpolation: no blending, no magnitude correction.
        return (float2)(cos(theta_interp) * mscale, sin(theta_interp) * mscale);
    }

    // Blend the interpolated and extrapolated angles by the ramp weight.
    const float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
    const float theta    = theta_interp * (1.0f - ramp_mix) + theta_extrap * ramp_mix;

    // n-d magnitude scaling corrected for interpolation.
    const float magnitude = mscale * (1.0f + 0.1f * log(1.0f / freq_scale));

    return (float2)(cos(theta) * magnitude, sin(theta) * magnitude);
}
28
// YaRN correction dimension for a given rotation count.
// Rotary pair x (elements 2x, 2x+1) has wavelength 2pi * base^(2x / n_dims).
// Over n_ctx_orig positions it therefore completes
//   n_rot = n_ctx_orig / (2pi * base^(2x / n_dims))
// full rotations. Solving for x gives
//   x = n_dims * log(n_ctx_orig / (n_rot * 2pi)) / (2 * log(base)).
// The result is a fractional pair index, the same unit in which
// rope_yarn_ramp compares i0 / 2.
float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
    const float rotation_len = n_rot * 2.0f * M_PI_F;            // n_rot * 2pi
    const float log_ratio    = log((float) n_ctx_orig / rotation_len);
    const float log_base_x2  = 2.0f * log(base);
    return ((float) n_dims * log_ratio) / log_base_x2;
}
34
// Start and end of the YaRN correction range, in rotary-pair units.
//   .s0 = floor(corr_factor(beta_fast)), clamped to >= 0
//   .s1 = ceil (corr_factor(beta_slow)), clamped to <= n_dims - 1
float2 rope_yarn_corr_dims(
    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
) {
    const float start = floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base));
    const float end   = ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base));

    float2 dims;
    dims.s0 = max(0.0f, start);
    dims.s1 = min((float) n_dims - 1.0f, end);
    return dims;
}
44
// "Normal" RoPE on f32 data: rotates adjacent pairs (x[i0], x[i0 + 1]).
//
// Work layout: one work-group per (i1, i2, i3) = group_id(0, 1, 2). The
// work-items of a group stride over i0 in steps of 2 * local_size.
// Pairs with i0 >= n_dims are copied through unchanged.
// The rotation position for the whole row is src1[i2].
// src2 holds optional per-pair frequency factors. When src2 == src0 (after
// offsets), the factor is 1.0 — presumably the host binds src0 in that slot
// when there is no factor tensor (TODO confirm against host code).
// All addressing uses the byte strides nb0x (src) / nbx (dst); the two
// elements of a pair are read and written as adjacent floats.
// ne00..ne03, ne1..ne3 and n_past are not used by this kernel.
kernel void kernel_rope_norm_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    // Apply the byte offset of each buffer view.
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global float *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    // Byte offsets of this row; the per-element term is added in the loop.
    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  position         = (float) src1[i2];
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        global const float * in  = (global const float *)((global const char *) src0 + src_row + i0*nb00);
        global float       * out = (global float *)((global char *) dst + dst_row + i0*nb0);

        if (i0 >= n_dims) {
            // Outside the rotated range: pass the pair through.
            out[0] = in[0];
            out[1] = in[1];
            continue;
        }

        const float  theta       = position * pow(freq_base, inv_ndims*i0);
        const float  freq_factor = has_freq_factors ? src2[i0/2] : 1.0f;
        const float2 cs          = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        // Read both inputs before writing (src0 and dst may alias).
        const float x0 = in[0];
        const float x1 = in[1];

        out[0] = x0*cs.s0 - x1*cs.s1;
        out[1] = x0*cs.s1 + x1*cs.s0;
    }
}
123
// "Normal" RoPE on f16 data: rotates adjacent pairs (x[i0], x[i0 + 1]).
// The math is done in float; results are converted back to half on store.
//
// Work layout: one work-group per (i1, i2, i3) = group_id(0, 1, 2). The
// work-items of a group stride over i0 in steps of 2 * local_size.
// Pairs with i0 >= n_dims are copied through unchanged.
// The rotation position for the whole row is src1[i2].
// src2 holds optional per-pair frequency factors. When src2 == src0 (after
// offsets), the factor is 1.0 — presumably the host binds src0 in that slot
// when there is no factor tensor (TODO confirm against host code).
// ne00..ne03, ne1..ne3 and n_past are not used by this kernel.
//
// dst is declared `global half *` because every store goes through a half
// pointer; it used to be declared `global float *`. The host still binds a
// plain buffer, so kernel-argument setup is unchanged.
kernel void kernel_rope_norm_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global half *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  position         = (float) src1[i2];
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        global const half * in  = (global const half *)((global const char *) src0 + src_row + i0*nb00);
        global half       * out = (global half *)((global char *) dst + dst_row + i0*nb0);

        if (i0 >= n_dims) {
            // Outside the rotated range: copy the pair bit-exactly.
            out[0] = in[0];
            out[1] = in[1];
            continue;
        }

        const float  theta       = position * pow(freq_base, inv_ndims*i0);
        const float  freq_factor = has_freq_factors ? src2[i0/2] : 1.0f;
        const float2 cs          = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        // Widen to float and read both inputs before writing (src0 and dst may alias).
        const float x0 = in[0];
        const float x1 = in[1];

        out[0] = x0*cs.s0 - x1*cs.s1;
        out[1] = x0*cs.s1 + x1*cs.s0;
    }
}
202
// NeoX-style RoPE on f32 data: rotates pairs (x[ic], x[ic + n_dims/2]) with
// ic = i0 / 2, i.e. the first and second halves of the rotated range.
//
// Work layout: one work-group per (i1, i2, i3) = group_id(0, 1, 2). The
// work-items of a group stride over i0 in steps of 2 * local_size.
// Elements i0, i0 + 1 with i0 >= n_dims are copied through unchanged.
// The rotation position for the whole row is src1[i2]. When src2 == src0
// (after offsets), no frequency factors are applied — presumably the host's
// "no factor tensor" convention (TODO confirm).
// ne00..ne03, ne1..ne3 and n_past are not used by this kernel.
kernel void kernel_rope_neox_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global float *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  position         = (float) src1[i2];
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;
    const int    pair_gap         = n_dims/2;   // element distance between the two members of a pair

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0, i0 + 1 through.
            global const float * in  = (global const float *)((global const char *) src0 + src_row + i0*nb00);
            global float       * out = (global float *)((global char *) dst + dst_row + i0*nb0);
            out[0] = in[0];
            out[1] = in[1];
            continue;
        }

        const int ic = i0/2;

        const float  theta       = position * pow(freq_base, inv_ndims*i0);
        const float  freq_factor = has_freq_factors ? src2[ic] : 1.0f;
        const float2 cs          = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global const float * in  = (global const float *)((global const char *) src0 + src_row + ic*nb00);
        global float       * out = (global float *)((global char *) dst + dst_row + ic*nb0);

        const float x0 = in[0];
        const float x1 = in[pair_gap];

        out[0]        = x0*cs.s0 - x1*cs.s1;
        out[pair_gap] = x0*cs.s1 + x1*cs.s0;
    }
}
281
// NeoX-style RoPE on f16 data: rotates pairs (x[ic], x[ic + n_dims/2]) with
// ic = i0 / 2. The math is done in float; results are converted back to half.
//
// Work layout: one work-group per (i1, i2, i3) = group_id(0, 1, 2). The
// work-items of a group stride over i0 in steps of 2 * local_size.
// Elements i0, i0 + 1 with i0 >= n_dims are copied through unchanged.
// The rotation position for the whole row is src1[i2]. When src2 == src0
// (after offsets), no frequency factors are applied — presumably the host's
// "no factor tensor" convention (TODO confirm).
// ne00..ne03, ne1..ne3 and n_past are not used by this kernel.
//
// dst is declared `global half *` because every store goes through a half
// pointer; it used to be declared `global float *`. The host still binds a
// plain buffer, so kernel-argument setup is unchanged.
kernel void kernel_rope_neox_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global half *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  position         = (float) src1[i2];
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;
    const int    pair_gap         = n_dims/2;   // element distance between the two members of a pair

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: copy elements i0, i0 + 1 bit-exactly.
            global const half * in  = (global const half *)((global const char *) src0 + src_row + i0*nb00);
            global half       * out = (global half *)((global char *) dst + dst_row + i0*nb0);
            out[0] = in[0];
            out[1] = in[1];
            continue;
        }

        const int ic = i0/2;

        const float  theta       = position * pow(freq_base, inv_ndims*i0);
        const float  freq_factor = has_freq_factors ? src2[ic] : 1.0f;
        const float2 cs          = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global const half * in  = (global const half *)((global const char *) src0 + src_row + ic*nb00);
        global half       * out = (global half *)((global char *) dst + dst_row + ic*nb0);

        const float x0 = in[0];
        const float x1 = in[pair_gap];

        out[0]        = x0*cs.s0 - x1*cs.s1;
        out[pair_gap] = x0*cs.s1 + x1*cs.s0;
    }
}
360
// Multi-section RoPE on f32 data with the NeoX pair layout
// (x[ic], x[ic + n_dims/2]), ic = i0 / 2.
//
// Each pair takes its position from one of four position planes in src1,
// src1[i2 + plane * axis], where axis is chosen from
// sector = (i0 / 2) % (s0 + s1 + s2 + s3):
//   is_imrope != 0 (interleaved): sector % 3 == 1 -> axis 1 ("h"),
//       == 2 -> axis 2 ("w"), == 0 -> axis 0 ("t"), each only while
//       sector < 3 * that section's size; otherwise axis 3 ("e").
//   is_imrope == 0 (contiguous): [0, s0) -> 0, [s0, s0+s1) -> 1,
//       [s0+s1, s0+s1+s2) -> 2, anything above -> 3.
// The interleaved path strides planes by ne02 and the contiguous path by ne2.
// That difference is kept from the original code
// (NOTE(review): presumably ne02 == ne2 here — confirm).
// Elements i0, i0 + 1 with i0 >= n_dims are copied through unchanged.
kernel void kernel_rope_multi_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections,
        int is_imrope
) {
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global float *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;
    const int    pair_gap         = n_dims/2;

    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
    const int sec_w     = sections.s1 + sections.s0;
    const int plane     = is_imrope ? ne02 : ne2;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0, i0 + 1 through.
            global const float * in  = (global const float *)((global const char *) src0 + src_row + i0*nb00);
            global float       * out = (global float *)((global char *) dst + dst_row + i0*nb0);
            out[0] = in[0];
            out[1] = in[1];
            continue;
        }

        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        int axis;
        if (is_imrope) {
            if      (sector % 3 == 1 && sector < 3 * sections.s1) axis = 1;
            else if (sector % 3 == 2 && sector < 3 * sections.s2) axis = 2;
            else if (sector % 3 == 0 && sector < 3 * sections.s0) axis = 0;
            else                                                  axis = 3;
        } else {
            if      (sector < sections.s0)         axis = 0;
            else if (sector < sec_w)               axis = 1;
            else if (sector < sec_w + sections.s2) axis = 2;
            else                                   axis = 3;
        }

        const float  theta_base  = (float) src1[i2 + plane*axis];
        const float  theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float  freq_factor = has_freq_factors ? src2[ic] : 1.0f;
        const float2 cs          = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global const float * in  = (global const float *)((global const char *) src0 + src_row + ic*nb00);
        global float       * out = (global float *)((global char *) dst + dst_row + ic*nb0);

        const float x0 = in[0];
        const float x1 = in[pair_gap];

        out[0]        = x0*cs.s0 - x1*cs.s1;
        out[pair_gap] = x0*cs.s1 + x1*cs.s0;
    }
}
471
// Multi-section RoPE on f16 data with the NeoX pair layout
// (x[ic], x[ic + n_dims/2]), ic = i0 / 2. The math is done in float;
// results are converted back to half on store.
//
// Each pair reads its position from one of four planes of src1, chosen by
// sector = (i0 / 2) % (s0 + s1 + s2 + s3):
//   is_imrope != 0: interleaved h/w/t selection by sector % 3, each only while
//                   sector < 3 * that section's size; otherwise plane 3 ("e").
//                   Planes are strided by ne02.
//   is_imrope == 0: contiguous ranges [0,s0) [s0,s0+s1) [s0+s1,s0+s1+s2) [rest].
//                   Planes are strided by ne2.
//   (NOTE(review): presumably ne02 == ne2 for this op — confirm.)
// Elements i0, i0 + 1 with i0 >= n_dims are copied through unchanged.
//
// Fix: dst is `global half *`, but the offset view was assigned from a
// `(global float *)` cast, which is an incompatible pointer conversion. It is
// now cast to `global half *`.
kernel void kernel_rope_multi_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections,
        int is_imrope
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst  = (global half*)((global char*)dst + offsetd);

    int i3 = get_group_id(2);
    int i2 = get_group_id(1);
    int i1 = get_group_id(0);

    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
    const int sec_w = sections.s1 + sections.s0;

    float inv_ndims = -1.f/n_dims;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 < n_dims) {
            int ic = i0/2;

            const int sector = (i0 / 2) % sect_dims;
            float theta_base = 0.0f;

            if (is_imrope) {
                if (sector % 3 == 1 && sector < 3 * sections.s1) { // h
                    theta_base = (float) pos[i2 + ne02 * 1];
                } else if (sector % 3 == 2 && sector < 3 * sections.s2) { // w
                    theta_base = (float) pos[i2 + ne02 * 2];
                } else if (sector % 3 == 0 && sector < 3 * sections.s0) { // t
                    theta_base = (float) pos[i2 + ne02 * 0];
                } else { // e
                    theta_base = (float) pos[i2 + ne02 * 3];
                }
            } else {
                if (sector < sections.s0) {
                    theta_base = pos[i2];
                } else if (sector < sec_w) {
                    theta_base = pos[i2 + ne2 * 1];
                } else if (sector < sec_w + sections.s2) {
                    theta_base = pos[i2 + ne2 * 2];
                } else {
                    theta_base = pos[i2 + ne2 * 3];
                }
            }

            const float theta = theta_base * pow(freq_base, inv_ndims*i0);

            // src2 == src0 means "no frequency factors" (host convention — TODO confirm).
            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

            global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
            global half * dst_data = (global half *)((global char *) dst  + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);

            // Widen to float; both reads happen before any write (src0 and dst may alias).
            const float x0 = src[0];
            const float x1 = src[n_dims/2];

            dst_data[0]        = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
            dst_data[n_dims/2] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
        } else {
            // Outside the rotated range: copy elements i0, i0 + 1 bit-exactly.
            global half * const src = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            global half * dst_data  = (global half *)((global char *) dst  + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
582
// Vision RoPE on f32 data: rotates pairs (x[ic], x[ic + n_dims]), ic = i0 / 2.
//
// The loop covers every i0 < ne0 with no n_dims guard. Staying in bounds and
// covering the whole row therefore assumes ne0 == 2 * n_dims (presumably
// enforced by the host — TODO confirm).
// With sector = ic % (s0 + s1):
//   sector <  s0       -> position src1[i2],       exponent index p = sector
//   s0 <= sector < s0+s1 -> position src1[i2 + ne2], p = sector - s0
// and theta = position * freq_base^(-2p / n_dims).
// Any other sector (only possible with negative section sizes) gives theta = 0.
kernel void kernel_rope_vision_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections
) {
    src0 = (global void *)((global char *) src0 + offset0);
    src1 = (global int *)((global char *) src1 + offset1);
    src2 = (global float *)((global char *) src2 + offset2);
    dst  = (global float *)((global char *) dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const ulong src_row = i1*nb01 + i2*nb02 + i3*nb03;
    const ulong dst_row = i1*nb1  + i2*nb2  + i3*nb3;

    const float2 corr_dims        = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
    const float  inv_ndims        = -1.f/n_dims;
    const bool   has_freq_factors = src2 != src0;
    const int    span             = sections.s0 + sections.s1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        const int ic     = i0/2;
        const int sector = ic % span;

        const bool in_first  = sector < sections.s0;
        const bool in_second = !in_first && sector < span;

        float theta_base = 0.0f;
        if (in_first || in_second) {
            const int p   = in_first ? sector : sector - sections.s0;
            const int row = in_first ? i2     : i2 + ne2;
            theta_base = src1[row] * pow(freq_base, inv_ndims*2.0f*p);
        }

        const float  freq_factor = has_freq_factors ? src2[ic] : 1.0f;
        const float2 cs          = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global const float * in  = (global const float *)((global const char *) src0 + src_row + ic*nb00);
        global float       * out = (global float *)((global char *) dst + dst_row + ic*nb0);

        const float x0 = in[0];
        const float x1 = in[n_dims];

        out[0]      = x0*cs.s0 - x1*cs.s1;
        out[n_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
665
// Vision RoPE on f16 data: rotates pairs (x[ic], x[ic + n_dims]), ic = i0 / 2.
// The math is done in float; results are converted back to half on store.
//
// The loop covers every i0 < ne0 with no n_dims guard. Staying in bounds
// assumes ne0 == 2 * n_dims (presumably enforced by the host — TODO confirm).
// With sector = ic % (s0 + s1):
//   sector <  s0         -> position src1[i2],       p = sector
//   s0 <= sector < s0+s1 -> position src1[i2 + ne2], p = sector - s0
// and theta = position * freq_base^(-2p / n_dims).
//
// Fix: dst is `global half *`, but the offset view was assigned from a
// `(global float *)` cast, which is an incompatible pointer conversion. It is
// now cast to `global half *`.
kernel void kernel_rope_vision_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections
) {
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst  = (global half*)((global char*)dst + offsetd);

    int i3 = get_group_id(2);
    int i2 = get_group_id(1);
    int i1 = get_group_id(0);

    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    const int sect_dims = sections.s0 + sections.s1;
    const int sec_w = sections.s1 + sections.s0;

    float inv_ndims = -1.f/n_dims;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        int ic = i0/2;

        const int sector = (i0/2) % sect_dims;
        float theta_base = 0.0f;

        if (sector < sections.s0) {
            const int p = sector;
            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*p);
        } else if (sector >= sections.s0 && sector < sec_w) {
            const int p = sector - sections.s0;
            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
        }

        // src2 == src0 means "no frequency factors" (host convention — TODO confirm).
        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        float2 cos_sin_theta = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global half * src      = (global half *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
        global half * dst_data = (global half *)((global char *) dst  + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);

        // Widen to float; both reads happen before any write (src0 and dst may alias).
        const float x0 = src[0];
        const float x1 = src[n_dims];

        dst_data[0]      = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
        dst_data[n_dims] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
    }
}