1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2
  3//------------------------------------------------------------------------------
  4// kernel_rope
  5//------------------------------------------------------------------------------
  6float rope_yarn_ramp(float low, float high, int i0) {
  7    const float y = (i0 / 2 - low) / max(0.001f, high - low);
  8    return 1.0f - min(1.0f, max(0.0f, y));
  9}
 10
 11// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 12// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 13float2 rope_yarn(
 14    float theta_extrap, float freq_scale, float2 corr_dims, int i0, float ext_factor, float mscale
 15) {
 16    // Get n-d rotational scaling corrected for extrapolation
 17    float theta_interp = freq_scale * theta_extrap;
 18    float theta = theta_interp;
 19    if (ext_factor != 0.0f) {
 20        float ramp_mix = rope_yarn_ramp(corr_dims.s0, corr_dims.s1, i0) * ext_factor;
 21        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 22
 23        // Get n-d magnitude scaling corrected for interpolation
 24        mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
 25    }
 26    return (float2)(cos(theta) * mscale, sin(theta) * mscale);
 27}
 28
// Solving `max_pos_emb = n_rot * 2pi * base^((2 * x) / n_dims)` for x, we get
// `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
 31float rope_yarn_corr_factor(int n_dims, int n_ctx_orig, float n_rot, float base) {
 32    return n_dims * log(n_ctx_orig / (n_rot * 2 * M_PI_F)) / (2 * log(base));
 33}
 34
 35float2 rope_yarn_corr_dims(
 36    int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow
 37) {
 38    // start and end correction dims
 39    return (float2)(
 40        max(0.0f,         floor(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_fast, freq_base))),
 41        min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)))
 42    );
 43}
 44
 45kernel void kernel_rope_norm_f32(
 46        global void * src0,
 47        ulong offset0,
 48        global int * src1,
 49        ulong offset1,
 50        global float * src2,
 51        ulong offset2,
 52        global float * dst,
 53        ulong offsetd,
 54        int ne00,
 55        int ne01,
 56        int ne02,
 57        int ne03,
 58        ulong nb00,
 59        ulong nb01,
 60        ulong nb02,
 61        ulong nb03,
 62        int ne0,
 63        int ne1,
 64        int ne2,
 65        int ne3,
 66        ulong nb0,
 67        ulong nb1,
 68        ulong nb2,
 69        ulong nb3,
 70        int n_past,
 71        int n_dims,
 72        int n_ctx_orig,
 73        float freq_base,
 74        float freq_scale,
 75        float ext_factor,
 76        float attn_factor,
 77        float beta_fast,
 78        float beta_slow
 79) {
 80    src0 = (global void*)((global char*)src0 + offset0);
 81    src1 = (global int*)((global char*)src1 + offset1);
 82    src2 = (global float*)((global char*)src2 + offset2);
 83    dst = (global float*)((global char*)dst + offsetd);
 84
 85    int i3 = get_group_id(2);
 86    int i2 = get_group_id(1);
 87    int i1 = get_group_id(0);
 88
 89    float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);
 90
 91    global int * pos = src1;
 92
 93    float theta_base = (float) pos[i2];
 94    float inv_ndims = -1.f/n_dims;
 95
 96    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
 97        if (i0 < n_dims) {
 98            int ic = i0/2;
 99
100            float theta = theta_base * pow(freq_base, inv_ndims*i0);
101
102            float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
103
104            float2 cos_sin_theta = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);
105
106            global float * src       = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
107            global float * dst_data  = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
108
109            float x0 = src[0];
110            float x1 = src[1];
111
112            dst_data[0] = x0*cos_sin_theta.s0 - x1*cos_sin_theta.s1;
113            dst_data[1] = x0*cos_sin_theta.s1 + x1*cos_sin_theta.s0;
114        } else {
115            global float * src      = (global float *)((global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
116            global float * dst_data = (global float *)((global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
117
118            dst_data[0] = src[0];
119            dst_data[1] = src[1];
120        }
121    }
122}
123
// Rotary embedding, "norm" layout, f16 data: rotates adjacent pairs (i0, i0 + 1).
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over the row two elements
// at a time. Elements are loaded as half, combined in float, and stored as half.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//                  (dst is declared `global float *` but is only accessed as half)
//   src1           int positions, indexed by i2
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   n_dims         pairs starting at i0 >= n_dims are copied unchanged
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00..ne03, ne1..ne3 and n_past are not
// read by this kernel.
kernel void kernel_rope_norm_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    // Advance every base pointer by its byte offset.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global float*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    const float theta_base = (float) src1[i2];
    const float inv_ndims  = -1.f/n_dims;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        global half * src_pair = (global half *)(src_row + i0*nb00);
        global half * dst_pair = (global half *)(dst_row + i0*nb0);

        if (i0 >= n_dims) {
            // Outside the rotated range: pass the pair through.
            dst_pair[0] = src_pair[0];
            dst_pair[1] = src_pair[1];
            continue;
        }

        const float theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float freq_factor = src2 != src0 ? src2[i0/2] : 1.0f;

        const float2 cs = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        // Load both inputs (widened to float) before storing anything.
        const float x0 = src_pair[0];
        const float x1 = src_pair[1];

        dst_pair[0] = x0*cs.s0 - x1*cs.s1;
        dst_pair[1] = x0*cs.s1 + x1*cs.s0;
    }
}
202
// Rotary embedding, "neox" layout, f32 data: pair number ic = i0 / 2 couples
// the element at index ic with the element n_dims / 2 positions after it.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//   src1           int positions, indexed by i2
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   n_dims         for i0 >= n_dims the elements i0 and i0 + 1 are copied unchanged
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00..ne03, ne1..ne3 and n_past are not
// read by this kernel.
kernel void kernel_rope_neox_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    // Advance every base pointer by its byte offset.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global float*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    const float theta_base = (float) src1[i2];
    const float inv_ndims  = -1.f/n_dims;
    const int   half_dims  = n_dims/2;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0 and i0 + 1 through.
            global float * src_tail = (global float *)(src_row + i0*nb00);
            global float * dst_tail = (global float *)(dst_row + i0*nb0);

            dst_tail[0] = src_tail[0];
            dst_tail[1] = src_tail[1];
            continue;
        }

        const int ic = i0/2;

        const float theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global float * src_lo = (global float *)(src_row + ic*nb00);
        global float * dst_lo = (global float *)(dst_row + ic*nb0);

        // Load both halves of the pair before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[half_dims];

        dst_lo[0]         = x0*cs.s0 - x1*cs.s1;
        dst_lo[half_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
281
// Rotary embedding, "neox" layout, f16 data: pair number ic = i0 / 2 couples
// the element at index ic with the element n_dims / 2 positions after it.
// Elements are loaded as half, combined in float, and stored as half.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//                  (dst is declared `global float *` but is only accessed as half)
//   src1           int positions, indexed by i2
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   n_dims         for i0 >= n_dims the elements i0 and i0 + 1 are copied unchanged
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00..ne03, ne1..ne3 and n_past are not
// read by this kernel.
kernel void kernel_rope_neox_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow
) {
    // Advance every base pointer by its byte offset.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global float*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    const float theta_base = (float) src1[i2];
    const float inv_ndims  = -1.f/n_dims;
    const int   half_dims  = n_dims/2;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0 and i0 + 1 through.
            global half * src_tail = (global half *)(src_row + i0*nb00);
            global half * dst_tail = (global half *)(dst_row + i0*nb0);

            dst_tail[0] = src_tail[0];
            dst_tail[1] = src_tail[1];
            continue;
        }

        const int ic = i0/2;

        const float theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global half * src_lo = (global half *)(src_row + ic*nb00);
        global half * dst_lo = (global half *)(dst_row + ic*nb0);

        // Load both halves of the pair (widened to float) before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[half_dims];

        dst_lo[0]         = x0*cs.s0 - x1*cs.s1;
        dst_lo[half_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
360
// Rotary embedding with multiple position planes, f32 data. Element pairing
// follows the "neox" pattern: pair number ic = i0 / 2 couples the element at
// index ic with the element n_dims / 2 positions after it.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//   src1           int positions holding up to four planes; plane k is read at
//                  index i2 + ne02 * k (is_imrope) or i2 + ne2 * k (otherwise)
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   n_dims         for i0 >= n_dims the elements i0 and i0 + 1 are copied unchanged
//   sections       four section sizes; the pair index modulo their sum is the
//                  "sector" that selects the position plane
//   is_imrope      non-zero selects the interleaved (sector % 3) plane mapping,
//                  zero selects the contiguous-range mapping
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00, ne01, ne03, ne1, ne3 and n_past are
// not read by this kernel.
kernel void kernel_rope_multi_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections,
        int  is_imrope
) {
    // Advance every base pointer by its byte offset.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global float*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
    const int sec_w     = sections.s1 + sections.s0;

    const float inv_ndims = -1.f/n_dims;
    const int   half_dims = n_dims/2;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0 and i0 + 1 through.
            global float * src_tail = (global float *)(src_row + i0*nb00);
            global float * dst_tail = (global float *)(dst_row + i0*nb0);

            dst_tail[0] = src_tail[0];
            dst_tail[1] = src_tail[1];
            continue;
        }

        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        // Pick the index into `pos` for this sector.
        int pos_idx;
        if (is_imrope) {
            // Interleaved mapping: the plane follows sector % 3 while the sector
            // is below three times that plane's section size; otherwise plane 3.
            const int lane = sector % 3;
            int plane = 3;                                          // e
            if (lane == 1 && sector < 3 * sections.s1) {            // h
                plane = 1;
            } else if (lane == 2 && sector < 3 * sections.s2) {     // w
                plane = 2;
            } else if (lane == 0 && sector < 3 * sections.s0) {     // t
                plane = 0;
            }
            pos_idx = i2 + ne02 * plane;
        } else {
            // Contiguous mapping: consecutive sector ranges map to planes 0..3.
            int plane;
            if (sector < sections.s0) {
                plane = 0;
            } else if (sector < sec_w) {
                plane = 1;
            } else if (sector < sec_w + sections.s2) {
                plane = 2;
            } else {
                plane = 3;
            }
            pos_idx = i2 + ne2 * plane;
        }

        const float theta_base = (float) pos[pos_idx];

        const float theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global float * src_lo = (global float *)(src_row + ic*nb00);
        global float * dst_lo = (global float *)(dst_row + ic*nb0);

        // Load both halves of the pair before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[half_dims];

        dst_lo[0]         = x0*cs.s0 - x1*cs.s1;
        dst_lo[half_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
471
// Rotary embedding with multiple position planes, f16 data. Element pairing
// follows the "neox" pattern: pair number ic = i0 / 2 couples the element at
// index ic with the element n_dims / 2 positions after it. Elements are loaded
// as half, combined in float, and stored as half.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//   src1           int positions holding up to four planes; plane k is read at
//                  index i2 + ne02 * k (is_imrope) or i2 + ne2 * k (otherwise)
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   n_dims         for i0 >= n_dims the elements i0 and i0 + 1 are copied unchanged
//   sections       four section sizes; the pair index modulo their sum is the
//                  "sector" that selects the position plane
//   is_imrope      non-zero selects the interleaved (sector % 3) plane mapping,
//                  zero selects the contiguous-range mapping
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00, ne01, ne03, ne1, ne3 and n_past are
// not read by this kernel.
kernel void kernel_rope_multi_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections,
        int  is_imrope
) {
    // Advance every base pointer by its byte offset. The cast for dst matches
    // its declared `global half *` type, so no incompatible pointer assignment
    // is made.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global half*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    const int sect_dims = sections.s0 + sections.s1 + sections.s2 + sections.s3;
    const int sec_w     = sections.s1 + sections.s0;

    const float inv_ndims = -1.f/n_dims;
    const int   half_dims = n_dims/2;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        if (i0 >= n_dims) {
            // Outside the rotated range: pass elements i0 and i0 + 1 through.
            global half * src_tail = (global half *)(src_row + i0*nb00);
            global half * dst_tail = (global half *)(dst_row + i0*nb0);

            dst_tail[0] = src_tail[0];
            dst_tail[1] = src_tail[1];
            continue;
        }

        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        // Pick the index into `pos` for this sector.
        int pos_idx;
        if (is_imrope) {
            // Interleaved mapping: the plane follows sector % 3 while the sector
            // is below three times that plane's section size; otherwise plane 3.
            const int lane = sector % 3;
            int plane = 3;                                          // e
            if (lane == 1 && sector < 3 * sections.s1) {            // h
                plane = 1;
            } else if (lane == 2 && sector < 3 * sections.s2) {     // w
                plane = 2;
            } else if (lane == 0 && sector < 3 * sections.s0) {     // t
                plane = 0;
            }
            pos_idx = i2 + ne02 * plane;
        } else {
            // Contiguous mapping: consecutive sector ranges map to planes 0..3.
            int plane;
            if (sector < sections.s0) {
                plane = 0;
            } else if (sector < sec_w) {
                plane = 1;
            } else if (sector < sec_w + sections.s2) {
                plane = 2;
            } else {
                plane = 3;
            }
            pos_idx = i2 + ne2 * plane;
        }

        const float theta_base = (float) pos[pos_idx];

        const float theta       = theta_base * pow(freq_base, inv_ndims*i0);
        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global half * src_lo = (global half *)(src_row + ic*nb00);
        global half * dst_lo = (global half *)(dst_row + ic*nb0);

        // Load both halves of the pair (widened to float) before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[half_dims];

        dst_lo[0]         = x0*cs.s0 - x1*cs.s1;
        dst_lo[half_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
582
// Rotary embedding, "vision" variant, f32 data: pair number ic = i0 / 2 couples
// the element at index ic with the element n_dims positions after it. There is
// no pass-through range; every pair visited by the loop is rotated.
// NOTE(review): the ic + n_dims access presumably relies on the host passing
// n_dims as half of the row length - confirm against the caller.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//   src1           int positions in two planes: index i2 for the first section,
//                  index i2 + ne2 for the second
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   sections       only .s0 and .s1 are used; the pair index modulo their sum
//                  is the "sector", and the frequency exponent restarts at the
//                  beginning of each section
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00..ne03, ne1, ne3 and n_past are not
// read by this kernel.
kernel void kernel_rope_vision_f32(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global float * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections
) {
    // Advance every base pointer by its byte offset.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global float*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    // Total width of the two sections; also the upper bound of the second one.
    const int sect_dims = sections.s0 + sections.s1;

    const float inv_ndims = -1.f/n_dims;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        // Angle stays 0 when the sector falls in neither section.
        float theta_base = 0.0f;
        if (sector < sections.s0) {
            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*sector);
        } else if (sector < sect_dims) {
            const int p = sector - sections.s0;
            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
        }

        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global float * src_lo = (global float *)(src_row + ic*nb00);
        global float * dst_lo = (global float *)(dst_row + ic*nb0);

        // Load both halves of the pair before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[n_dims];

        dst_lo[0]      = x0*cs.s0 - x1*cs.s1;
        dst_lo[n_dims] = x0*cs.s1 + x1*cs.s0;
    }
}
665
// Rotary embedding, "vision" variant, f16 data: pair number ic = i0 / 2 couples
// the element at index ic with the element n_dims positions after it. There is
// no pass-through range; every pair visited by the loop is rotated. Elements
// are loaded as half, combined in float, and stored as half.
// NOTE(review): the ic + n_dims access presumably relies on the host passing
// n_dims as half of the row length - confirm against the caller.
//
// One work-group handles the row selected by its group ids (i1, i2, i3 taken
// from NDRange dims 0, 1, 2); its work-items stride over i0 in steps of two.
//
//   src0 / dst     input / output, addressed with byte strides nb00..nb03 / nb0..nb3
//   src1           int positions in two planes: index i2 for the first section,
//                  index i2 + ne2 for the second
//   src2           per-pair frequency factors; read only when the pointer
//                  differs from src0 once the offsets are applied
//   offset0/1/2/d  byte offsets added to the matching base pointer
//   ne0            row length covered by the loop
//   sections       only .s0 and .s1 are used; the pair index modulo their sum
//                  is the "sector", and the frequency exponent restarts at the
//                  beginning of each section
//
// freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow and
// n_ctx_orig feed the YaRN helpers. ne00..ne03, ne1, ne3 and n_past are not
// read by this kernel.
kernel void kernel_rope_vision_f16(
        global void * src0,
        ulong offset0,
        global int * src1,
        ulong offset1,
        global float * src2,
        ulong offset2,
        global half * dst,
        ulong offsetd,
        int ne00,
        int ne01,
        int ne02,
        int ne03,
        ulong nb00,
        ulong nb01,
        ulong nb02,
        ulong nb03,
        int ne0,
        int ne1,
        int ne2,
        int ne3,
        ulong nb0,
        ulong nb1,
        ulong nb2,
        ulong nb3,
        int n_past,
        int n_dims,
        int n_ctx_orig,
        float freq_base,
        float freq_scale,
        float ext_factor,
        float attn_factor,
        float beta_fast,
        float beta_slow,
        int4 sections
) {
    // Advance every base pointer by its byte offset. The cast for dst matches
    // its declared `global half *` type, so no incompatible pointer assignment
    // is made.
    src0 = (global void*)((global char*)src0 + offset0);
    src1 = (global int*)((global char*)src1 + offset1);
    src2 = (global float*)((global char*)src2 + offset2);
    dst = (global half*)((global char*)dst + offsetd);

    const int i1 = get_group_id(0);
    const int i2 = get_group_id(1);
    const int i3 = get_group_id(2);

    const float2 corr_dims = rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow);

    global int * pos = src1;

    // Total width of the two sections; also the upper bound of the second one.
    const int sect_dims = sections.s0 + sections.s1;

    const float inv_ndims = -1.f/n_dims;

    // First byte of this work-group's row in the source and destination.
    global char * const src_row = (global char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    global char * const dst_row = (global char *)  dst + i3*nb3  + i2*nb2  + i1*nb1;

    for (int i0 = 2*get_local_id(0); i0 < ne0; i0 += 2*get_local_size(0)) {
        const int ic     = i0/2;
        const int sector = ic % sect_dims;

        // Angle stays 0 when the sector falls in neither section.
        float theta_base = 0.0f;
        if (sector < sections.s0) {
            theta_base = pos[i2] * pow(freq_base, inv_ndims*2.0f*sector);
        } else if (sector < sect_dims) {
            const int p = sector - sections.s0;
            theta_base = pos[i2 + ne2] * pow(freq_base, inv_ndims*2.0f*p);
        }

        const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;

        const float2 cs = rope_yarn(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor);

        global half * src_lo = (global half *)(src_row + ic*nb00);
        global half * dst_lo = (global half *)(dst_row + ic*nb0);

        // Load both halves of the pair (widened to float) before storing anything.
        const float x0 = src_lo[0];
        const float x1 = src_lo[n_dims];

        dst_lo[0]      = x0*cs.s0 - x1*cs.s1;
        dst_lo[n_dims] = x0*cs.s1 + x1*cs.s0;
    }
}