1#include "rope.hpp"
  2#include "ggml-sycl/common.hpp"
  3#include "ggml.h"
  4
  5struct rope_corr_dims {
  6    float v[2];
  7};
  8
  9struct mrope_sections {
 10    int v[4];
 11};
 12
 13static float rope_yarn_ramp(const float low, const float high, const int i0) {
 14    const float y = (i0 / 2 - low) / sycl::max(0.001f, high - low);
 15    return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y));
 16}
 17
 18// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 19// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 20static void rope_yarn(
 21    float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
 22    float * cos_theta, float * sin_theta) {
 23    // Get n-d rotational scaling corrected for extrapolation
 24    float theta_interp = freq_scale * theta_extrap;
 25    float theta = theta_interp;
 26    if (ext_factor != 0.0f) {
 27        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
 28        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 29
 30        // Get n-d magnitude scaling corrected for interpolation
 31        mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale);
 32    }
 33    *cos_theta = sycl::cos(theta) * mscale;
 34    *sin_theta = sycl::sin(theta) * mscale;
 35}
 36
 37template <typename T, bool has_ff>
 38static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
 39                      const int32_t * pos, float freq_scale, float ext_factor, float attn_factor,
 40                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
 41                      const sycl::nd_item<3> & item_ct1) {
 42    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 43
 44    if (i0 >= ne0) {
 45        return;
 46    }
 47
 48    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 49
 50    const int row0     = row % ne1;
 51    const int channel0 = row / ne1;
 52
 53    const int i  = row * ne0 + i0;
 54    const int i2 = channel0 * s2 + row0 * s1 + i0;
 55
 56    if (i0 >= n_dims) {
 57        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2);
 58        return;
 59    }
 60
 61    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
 62
 63    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
 64
 65    float cos_theta;
 66    float sin_theta;
 67
 68    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
 69
 70    const float x0 = x[i2 + 0];
 71    const float x1 = x[i2 + 1];
 72
 73    dst[i + 0] = x0 * cos_theta - x1 * sin_theta;
 74    dst[i + 1] = x0 * sin_theta + x1 * cos_theta;
 75}
 76
 77template <typename T, bool has_ff>
 78static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
 79                      const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
 80                      const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors,
 81                      const sycl::nd_item<3> & item_ct1) {
 82    const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1));
 83
 84    if (i0 >= ne0) {
 85        return;
 86    }
 87
 88    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 89
 90    const int row0     = row % ne1;
 91    const int channel0 = row / ne1;
 92
 93    const int i  = row * ne0 + i0 / 2;
 94    const int i2 = channel0 * s2 + row0 * s1 + i0 / 2;
 95
 96    if (i0 >= n_dims) {
 97        *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2);
 98        return;
 99    }
100
101    const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f);
102
103    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
104
105    float cos_theta;
106    float sin_theta;
107
108    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
109
110    const float x0 = x[i2 + 0];
111    const float x1 = x[i2 + n_dims / 2];
112
113    dst[i + 0]          = x0 * cos_theta - x1 * sin_theta;
114    dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
115}
116
117template <typename T, bool has_ff>
118static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
119                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
120                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
121                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
122                        const bool is_imrope, const sycl::nd_item<3> & item_ct1) {
123    // get index pos
124    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
125    if (i0 >= ne0) {
126        return;
127    }
128    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
129
130    const int    row_x     = row_dst % ne1;
131    const int    channel_x = row_dst / ne1;
132    const int    idst      = (row_dst * ne0) + (i0 / 2);
133    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
134
135    if (i0 >= n_dims) {
136        *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix);
137        return;
138    }
139
140    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
141    const int sec_w = sections.v[1] + sections.v[0];
142    const int sector = (i0 / 2) % sect_dims;
143
144
145    float theta_base = 0.0;
146    if (is_imrope) {
147        if (sector % 3 == 1 && sector < 3 * sections.v[1]) {
148            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
149        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) {
150            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
151        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) {
152            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
153        } else {
154            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
155        }
156    } else {
157        if (sector < sections.v[0]) {
158            theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f);
159        }
160        else if (sector >= sections.v[0] && sector < sec_w) {
161            theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f);
162        }
163        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
164            theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f);
165        }
166        else if (sector >= sec_w + sections.v[2]) {
167            theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f);
168        }
169    }
170
171    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
172    float       cos_theta;
173    float       sin_theta;
174    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
175    const float x0 = x[ix + 0];
176    const float x1 = x[ix + n_dims/2];
177
178    // store results in dst
179    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
180    dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta;
181}
182
183
184
185template <typename T, bool has_ff>
186static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
187                        const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale,
188                        const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
189                        const float theta_scale, const float * freq_factors, const mrope_sections sections,
190                        const sycl::nd_item<3> & item_ct1) {
191    // get index pos
192    const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1));
193    if (i0 >= ne0) {
194        return;
195    }
196    const int    row_dst   = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2);
197    const int    row_x     = row_dst % ne1;
198    const int    channel_x = row_dst / ne1;
199    const int    idst      = (row_dst * ne0) + (i0 / 2);
200    const size_t ix        = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2);
201
202    const int sect_dims = sections.v[0] + sections.v[1];
203    const int sector    = (i0 / 2) % sect_dims;
204
205    float theta_base = 0.0f;
206    if (sector < sections.v[0]) {
207        const int p = sector;
208        theta_base  = pos[channel_x] * sycl::pow(theta_scale, (float) p);
209    } else {
210        const int p = sector - sections.v[0];
211        theta_base  = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
212    }
213
214    const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f;
215    float       cos_theta;
216    float       sin_theta;
217    rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
218    const float x0 = x[ix + 0];
219    const float x1 = x[ix + n_dims];
220
221    // store results in dst
222    dst[idst + 0]      = x0 * cos_theta - x1 * sin_theta;
223    dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta;
224}
225
226template <typename T>
227static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
228                           const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base,
229                           const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
230                           const float * freq_factors, queue_ptr stream) {
231    GGML_ASSERT(ne0 % 2 == 0);
232    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
233    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
234    const sycl::range<3> block_nums(1, num_blocks_x, nr);
235
236    const float theta_scale = powf(freq_base, -2.0f / n_dims);
237
238    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
239
240    if (freq_factors == nullptr) {
241        /*
242        DPCT1049:40: The work-group size passed to the SYCL kernel may exceed
243        the limit. To get the device limit, query
244        info::device::max_work_group_size. Adjust the work-group size if needed.
245        */
246        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
247            rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
248                                theta_scale, freq_factors, item_ct1);
249        });
250    } else {
251        /*
252        DPCT1049:41: The work-group size passed to the SYCL kernel may exceed
253        the limit. To get the device limit, query
254        info::device::max_work_group_size. Adjust the work-group size if needed.
255        */
256        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
257            rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
258                               theta_scale, freq_factors, item_ct1);
259        });
260    }
261}
262
263template <typename T>
264static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2,
265                           const int n_dims, const int nr, const int32_t * pos, const float freq_scale,
266                           const float freq_base, const float ext_factor, const float attn_factor,
267                           const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) {
268    GGML_ASSERT(ne0 % 2 == 0);
269    const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
270    const int            num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
271    const sycl::range<3> block_nums(1, num_blocks_x, nr);
272
273    const float theta_scale = powf(freq_base, -2.0f / n_dims);
274
275    dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
276
277    if (freq_factors == nullptr) {
278        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
279            rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
280                                theta_scale, freq_factors, item_ct1);
281        });
282    } else {
283        stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
284            rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims,
285                               theta_scale, freq_factors, item_ct1);
286        });
287    }
288}
289
290template <typename T>
291static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
292                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
293                             const float freq_scale, const float freq_base, const float ext_factor,
294                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
295                             const mrope_sections sections, const bool is_imrope, queue_ptr stream) {
296    GGML_ASSERT(ne0 % 2 == 0);
297    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
298    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
299    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
300    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
301
302    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
303    // Add FP16 capability check if T could be sycl::half
304    if constexpr (std::is_same_v<T, sycl::half>) {
305        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
306    }
307    // launch kernel
308    if (freq_factors == nullptr) {
309        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
310            rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
311                                  corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
312        });
313    } else {
314        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
315            rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
316                                 corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1);
317        });
318    }
319}
320
321
322
323
324// rope vision
325template <typename T>
326static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1,
327                             const size_t s2, const int n_dims, const int nr, const int32_t * pos,
328                             const float freq_scale, const float freq_base, const float ext_factor,
329                             const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors,
330                             const mrope_sections sections, queue_ptr stream) {
331    GGML_ASSERT(ne0 % 2 == 0);
332    const sycl::range<3>    block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1);
333    const int               n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE));
334    const sycl::range<3>    grid_dims(1, n_blocks_y, nr);
335    const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims);
336
337    const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
338    // Add FP16 capability check if T could be sycl::half
339    if constexpr (std::is_same_v<T, sycl::half>) {
340        dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
341    }
342    // launch kernel
343    if (freq_factors == nullptr) {
344        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
345            rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
346                                  corr_dims, theta_scale, freq_factors, sections, item_ct1);
347        });
348    } else {
349        stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) {
350            rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor,
351                                 corr_dims, theta_scale, freq_factors, sections, item_ct1);
352        });
353    }
354}
355
356inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
357
358    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
359    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
360    GGML_ASSERT(dst->src[0]->type == dst->type);
361    const int64_t ne00 = dst->src[0]->ne[0]; // head dims
362    const int64_t ne01 = dst->src[0]->ne[1]; // num heads
363    const int64_t ne02 = dst->src[0]->ne[2]; // num heads
364    const int64_t nr = ggml_nrows(dst->src[0]);
365
366    const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type);
367    const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type);
368
369
370    //const int n_past      = ((int32_t *) dst->op_params)[0];
371    const int n_dims      = ((int32_t *) dst->op_params)[1];
372    const int mode        = ((int32_t *) dst->op_params)[2];
373    //const int n_ctx       = ((int32_t *) dst->op_params)[3];
374    const int n_ctx_orig  = ((int32_t *) dst->op_params)[4];
375    mrope_sections sections;
376
377    // RoPE alteration for extended context
378    float freq_base;
379    float freq_scale;
380    float ext_factor;
381    float attn_factor;
382    float beta_fast;
383    float beta_slow;
384
385    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
386    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
387    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
388    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
389    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
390    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
391    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);
392
393    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
394    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
395    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
396    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
397
398    if (is_mrope) {
399        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
400    }
401
402    if (is_vision) {
403        GGML_ASSERT(n_dims == ne00/2);
404    }
405
406    const int32_t * pos = (const int32_t *) dst->src[1]->data;
407
408    const float * freq_factors = nullptr;
409    if (dst->src[2] != nullptr) {
410        freq_factors = (const float *) dst->src[2]->data;
411    }
412
413    rope_corr_dims corr_dims;
414    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);
415
416    dpct::queue_ptr main_stream = ctx.stream();
417    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
418
419    // compute
420    if (is_neox) {
421        GGML_SYCL_DEBUG("%s: neox path\n", __func__);
422        if (dst->src[0]->type == GGML_TYPE_F32) {
423            rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
424                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
425        } else if (dst->src[0]->type == GGML_TYPE_F16) {
426            rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
427                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
428                           main_stream);
429        } else {
430            GGML_ABORT("fatal error");
431        }
432    } else if (is_mrope && !is_vision) {
433        GGML_SYCL_DEBUG("%s: mrope path\n", __func__);
434        if (dst->src[0]->type == GGML_TYPE_F16) {
435            rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01,
436                s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
437                freq_factors, sections, is_imrope, main_stream);
438        } else if (dst->src[0]->type == GGML_TYPE_F32) {
439            rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
440                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
441                             is_imrope, main_stream);
442        } else {
443            GGML_ABORT("Fatal error: Tensor type unsupported!");
444        }
445    } else if (is_vision) {
446        GGML_SYCL_DEBUG("%s: vision path\n", __func__);
447        if (dst->src[0]->type == GGML_TYPE_F16) {
448            rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01,
449                             s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
450                             freq_factors, sections, main_stream);
451        } else if (dst->src[0]->type == GGML_TYPE_F32) {
452            rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims,
453                             nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections,
454                             main_stream);
455        } else {
456            GGML_ABORT("Fatal error: Tensor type unsupported!");
457        }
458    } else {
459        GGML_SYCL_DEBUG("%s: norm path\n", __func__);
460        if (dst->src[0]->type == GGML_TYPE_F32) {
461            rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr,
462                           pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream);
463        } else if (dst->src[0]->type == GGML_TYPE_F16) {
464            rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02,
465                           n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors,
466                           main_stream);
467        } else {
468            GGML_ABORT("fatal error");
469        }
470    }
471}
472
473void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
474    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3);
475    ggml_sycl_op_rope(ctx, dst);
476}
477