1#include "convert.cuh"
  2#include "ggml-cuda/common.cuh"
  3#include "ggml.h"
  4#include "rope.cuh"
  5
  6struct rope_corr_dims {
  7    float v[2];
  8};
  9
 10
 11struct mrope_sections {
 12    int v[4];
 13};
 14
 15static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
 16    const float y = (i0 / 2 - low) / max(0.001f, high - low);
 17    return 1.0f - min(1.0f, max(0.0f, y));
 18}
 19
 20// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 21// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 22template<bool forward>
 23static __device__ void rope_yarn(
 24        const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
 25        float mscale, float & cos_theta, float & sin_theta) {
 26    // Get n-d rotational scaling corrected for extrapolation
 27    float theta_interp = freq_scale * theta_extrap;
 28    float theta = theta_interp;
 29    if (ext_factor != 0.0f) {
 30        float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
 31        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 32
 33        // Get n-d magnitude scaling corrected for interpolation
 34        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
 35    }
 36    cos_theta = cosf(theta) * mscale;
 37    sin_theta = sinf(theta) * mscale;
 38    if (!forward) {
 39        sin_theta *= -1.0f;
 40    }
 41}
 42
 43template <bool forward, bool has_ff, typename T, typename D>
 44static __global__ void rope_norm(const T *            x,
 45                                 D *                  dst,
 46                                 const int            ne00,
 47                                 const int            ne01,
 48                                 const int            ne02,
 49                                 const int            s01,
 50                                 const int            s02,
 51                                 const int            s03,
 52                                 const int            s1,
 53                                 const int            s2,
 54                                 const int            s3,
 55                                 const int            n_dims,
 56                                 const int32_t *      pos,
 57                                 const float          freq_scale,
 58                                 const float          ext_factor,
 59                                 const float          attn_factor,
 60                                 const rope_corr_dims corr_dims,
 61                                 const float          theta_scale,
 62                                 const float *        freq_factors,
 63                                 const int64_t *      row_indices,
 64                                 const int            set_rows_stride) {
 65    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 66
 67    if (i0 >= ne00) {
 68        return;
 69    }
 70
 71    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
 72
 73    const uint32_t i3 = row_dst / (ne01 * ne02);
 74    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
 75    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
 76
 77    int       idst = i0 + i1 * s1  + i2 * s2  + i3 * s3;
 78    const int ix   = i0 + i1 * s01 + i2 * s02 + i3 * s03;
 79    // Fusion optimization: ROPE + VIEW + SET_ROWS.
 80    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
 81    if (set_rows_stride != 0) {
 82        idst = i1 * s1 + i0;
 83        idst += row_indices[i2] * set_rows_stride;
 84    }
 85
 86    const auto & store_coaelsced = [&](float x0, float x1) {
 87        if constexpr (std::is_same_v<float, D>) {
 88            float2 v = make_float2(x0, x1);
 89            ggml_cuda_memcpy_1<8>(dst + idst, &v);
 90        } else if constexpr (std::is_same_v<half, D>) {
 91            half2 v = make_half2(x0, x1);
 92            ggml_cuda_memcpy_1<4>(dst + idst, &v);
 93        }
 94    };
 95    if (i0 >= n_dims) {
 96        store_coaelsced(x[ix + 0], x[ix + 1]);
 97        return;
 98    }
 99
100    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
101
102    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
103
104    float cos_theta;
105    float sin_theta;
106
107    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);
108
109    const float x0 = x[ix + 0];
110    const float x1 = x[ix + 1];
111
112    store_coaelsced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
113}
114
// RoPE, NeoX layout: the pair rotated together is (i0/2, i0/2 + n_dims/2), i.e. matching elements
// of the first and second half of the rotated dimensions, instead of adjacent elements.
//
// Launch layout and stride conventions match rope_norm (see rope_neox_cuda). row_dst is the
// flattened row index i1 + ne01*(i2 + ne02*i3), and each thread along y owns
// i0 = 2*(global y index). Threads with i0 >= n_dims copy elements i0 and i0 + 1 unrotated.
//
// x / dst         : source elements of type T, destination elements of type D (converted with
//                   ggml_cuda_cast).
// row_indices / set_rows_stride : fused ROPE + VIEW + SET_ROWS scatter (off when the stride is 0).
//                   On that path the destination row is row_indices[i2] and i3 is ignored.
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T *            x,
                                 D *                  dst,
                                 const int            ne00,
                                 const int            ne01,
                                 const int            ne02,
                                 const int            s01,
                                 const int            s02,
                                 const int            s03,
                                 const int            s1,
                                 const int            s2,
                                 const int            s3,
                                 const int            n_dims,
                                 const int32_t *      pos,
                                 const float          freq_scale,
                                 const float          ext_factor,
                                 const float          attn_factor,
                                 const rope_corr_dims corr_dims,
                                 const float          theta_scale,
                                 const float *        freq_factors,
                                 const int64_t *      row_indices,
                                 const int            set_rows_stride) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // Unflatten row_dst = i1 + ne01*(i2 + ne02*i3).
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    // 64-bit destination offset. On the fused path it includes the int64 product
    // row_indices[i2] * set_rows_stride, which must not be truncated to int.
    int64_t   idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
    if (set_rows_stride != 0) {
        idst = i1 * s1 + i0 / 2;
        idst += row_indices[i2] * set_rows_stride;
    }

    if (i0 >= n_dims) {
        // idst and ix already include i0/2; adding it again addresses elements i0 and i0 + 1.
        dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
        dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);

        return;
    }

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]          = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
    dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
181
// RoPE with multi-axis positions (mrope and interleaved imrope). Pairs are rotated NeoX-style
// (element i0/2 with element i0/2 + n_dims/2). The position for each pair comes from one of four
// position axes stored back to back in pos (pos[i2 + ne02*axis], axis = 0..3). The axis is chosen
// from sector = (i0/2) % (v0 + v1 + v2 + v3), where vk = sections.v[k]:
//   - contiguous (is_imrope == false): [0, v0) -> 0, [v0, v0+v1) -> 1, [v0+v1, v0+v1+v2) -> 2,
//     everything after -> 3;
//   - interleaved (is_imrope == true): sector % 3 == 1 / 2 / 0 selects axis 1 / 2 / 0 while the
//     sector is below 3*v1 / 3*v2 / 3*v0 respectively; all other sectors use axis 3.
// Elements with i0 >= n_dims are copied unchanged. x and dst share element type T.
template <bool forward, bool has_ff, typename T>
static __global__ void rope_multi(const T *            x,
                                  T *                  dst,
                                  const int            ne00,
                                  const int            ne01,
                                  const int            ne02,
                                  const int            s01,
                                  const int            s02,
                                  const int            s03,
                                  const int            s1,
                                  const int            s2,
                                  const int            s3,
                                  const int            n_dims,
                                  const int32_t *      pos,
                                  const float          freq_scale,
                                  const float          ext_factor,
                                  const float          attn_factor,
                                  const rope_corr_dims corr_dims,
                                  const float          theta_scale,
                                  const float *        freq_factors,
                                  const mrope_sections sections,
                                  const bool           is_imrope) {
    const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // Unflatten row_dst = i1 + ne01*(i2 + ne02*i3).
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    const int idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    if (i0 >= n_dims) {
        // idst and ix already include i0/2, so this copies elements i0 and i0 + 1.
        const int half_i0 = i0 / 2;
        dst[idst + half_i0 + 0] = x[ix + half_i0 + 0];
        dst[idst + half_i0 + 1] = x[ix + half_i0 + 1];
        return;
    }

    const int th_end = sections.v[1] + sections.v[0]; // end of the axis-1 run in contiguous mode
    const int n_sect = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int sector = (i0 / 2) % n_sect;

    int axis;
    if (is_imrope) {
        const int lane = sector % 3;
        if (lane == 1 && sector < 3 * sections.v[1]) {
            axis = 1; // h
        } else if (lane == 2 && sector < 3 * sections.v[2]) {
            axis = 2; // w
        } else if (lane == 0 && sector < 3 * sections.v[0]) {
            axis = 0; // t
        } else {
            axis = 3;
        }
    } else if (sector < sections.v[0]) {
        axis = 0;
    } else if (sector < th_end) {
        axis = 1;
    } else if (sector < th_end + sections.v[2]) {
        axis = 2;
    } else {
        axis = 3;
    }

    const float theta_base = pos[i2 + ne02 * axis] * powf(theta_scale, i0 / 2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
266
// RoPE for vision inputs. The caller asserts n_dims == ne00/2, and element i0/2 is rotated
// together with element i0/2 + n_dims. Two position axes are stored back to back in pos. With
// sector = (i0/2) % (v0 + v1):
//   - sectors [0, v0) use pos[i2] with exponent `sector`;
//   - sectors [v0, v0+v1) use pos[i2 + ne02] with exponent `sector - v0`.
// x and dst share element type T.
template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T *            x,
                                   T *                  dst,
                                   const int            ne00,
                                   const int            ne01,
                                   const int            ne02,
                                   const int            s01,
                                   const int            s02,
                                   const int            s03,
                                   const int            s1,
                                   const int            s2,
                                   const int            s3,
                                   const int            n_dims,
                                   const int32_t *      pos,
                                   const float          freq_scale,
                                   const float          ext_factor,
                                   const float          attn_factor,
                                   const rope_corr_dims corr_dims,
                                   const float          theta_scale,
                                   const float *        freq_factors,
                                   const mrope_sections sections) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // Unflatten row_dst = i1 + ne01*(i2 + ne02*i3).
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    const int idst = i0 / 2 + i1 * s1  + i2 * s2  + i3 * s3;
    const int ix   = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    const int n_sect = sections.v[0] + sections.v[1];
    const int sector = (i0 / 2) % n_sect;

    float theta_base = 0.0f;
    if (sector < sections.v[0]) {
        theta_base = pos[i2] * powf(theta_scale, sector);
    } else if (sector < n_sect) {
        theta_base = pos[i2 + ne02] * powf(theta_scale, sector - sections.v[0]);
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims];

    dst[idst + 0]      = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
329
// Host launcher for rope_norm.
// Grid: x = nr (one block column per row), y = enough CUDA_ROPE_BLOCK_SIZE-thread blocks to cover
// the ne00/2 element pairs of a row. Block: (1, CUDA_ROPE_BLOCK_SIZE, 1).
// theta_scale = freq_base^(-2/n_dims) is computed once here. The has_ff template flag follows
// whether freq_factors is provided.
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T *            x,
                           D *                  dst,
                           const int            ne00,
                           const int            ne01,
                           const int            ne02,
                           const int            s01,
                           const int            s02,
                           const int            s03,
                           const int            s1,
                           const int            s2,
                           const int            s3,
                           const int            n_dims,
                           const int            nr,
                           const int32_t *      pos,
                           const float          freq_scale,
                           const float          freq_base,
                           const float          ext_factor,
                           const float          attn_factor,
                           const rope_corr_dims corr_dims,
                           const float *        freq_factors,
                           const int64_t *      row_indices,
                           const int            set_rows_stride,
                           cudaStream_t         stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    const int  n_blocks_y = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
    const dim3 grid_dims(nr, n_blocks_y, 1);
    const dim3 thread_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_norm<forward, true><<<grid_dims, thread_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
        return;
    }

    rope_norm<forward, false><<<grid_dims, thread_dims, 0, stream>>>(
        x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
371
// Host launcher for rope_neox.
// Grid: x = nr (one block column per row), y = enough CUDA_ROPE_BLOCK_SIZE-thread blocks to cover
// the ne00/2 element pairs of a row. Block: (1, CUDA_ROPE_BLOCK_SIZE, 1).
// theta_scale = freq_base^(-2/n_dims) is computed once here. The has_ff template flag follows
// whether freq_factors is provided.
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T *            x,
                           D *                  dst,
                           const int            ne00,
                           const int            ne01,
                           const int            ne02,
                           const int            s01,
                           const int            s02,
                           const int            s03,
                           const int            s1,
                           const int            s2,
                           const int            s3,
                           const int            n_dims,
                           const int            nr,
                           const int32_t *      pos,
                           const float          freq_scale,
                           const float          freq_base,
                           const float          ext_factor,
                           const float          attn_factor,
                           const rope_corr_dims corr_dims,
                           const float *        freq_factors,
                           const int64_t *      row_indices,
                           const int            set_rows_stride,
                           cudaStream_t         stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    const int  n_blocks_y = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
    const dim3 grid_dims(nr, n_blocks_y, 1);
    const dim3 thread_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_neox<forward, true><<<grid_dims, thread_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
        return;
    }

    rope_neox<forward, false><<<grid_dims, thread_dims, 0, stream>>>(
        x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
413
// Host launcher for rope_multi (mrope / imrope).
// Grid: x = nr (one block column per row), y = enough CUDA_ROPE_BLOCK_SIZE-thread blocks to cover
// the ne00/2 element pairs of a row. Block: (1, CUDA_ROPE_BLOCK_SIZE, 1).
// theta_scale = freq_base^(-2/n_dims). The has_ff template flag follows whether freq_factors is given.
template <bool forward, typename T>
static void rope_multi_cuda(const T *            x,
                            T *                  dst,
                            const int            ne00,
                            const int            ne01,
                            const int            ne02,
                            const int            s01,
                            const int            s02,
                            const int            s03,
                            const int            s1,
                            const int            s2,
                            const int            s3,
                            const int            n_dims,
                            const int            nr,
                            const int32_t *      pos,
                            const float          freq_scale,
                            const float          freq_base,
                            const float          ext_factor,
                            const float          attn_factor,
                            const rope_corr_dims corr_dims,
                            const float *        freq_factors,
                            const mrope_sections sections,
                            const bool           is_imrope,
                            cudaStream_t         stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    const int  n_blocks_y = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
    const dim3 grid_dims(nr, n_blocks_y, 1);
    const dim3 thread_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_multi<forward, true, T><<<grid_dims, thread_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
        return;
    }

    rope_multi<forward, false, T><<<grid_dims, thread_dims, 0, stream>>>(
        x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
455
// Host launcher for rope_vision.
// Grid: x = nr (one block column per row, i.e. per (i1, i2, i3)), y = enough
// CUDA_ROPE_BLOCK_SIZE-thread blocks to cover the ne00/2 element pairs of a row.
// Block: (1, CUDA_ROPE_BLOCK_SIZE, 1).
// theta_scale = freq_base^(-2/n_dims). The has_ff template flag follows whether freq_factors is given.
template <bool forward, typename T>
static void rope_vision_cuda(const T *            x,
                             T *                  dst,
                             const int            ne00,
                             const int            ne01,
                             const int            ne02,
                             const int            s01,
                             const int            s02,
                             const int            s03,
                             const int            s1,
                             const int            s2,
                             const int            s3,
                             const int            n_dims,
                             const int            nr,
                             const int32_t *      pos,
                             const float          freq_scale,
                             const float          freq_base,
                             const float          ext_factor,
                             const float          attn_factor,
                             const rope_corr_dims corr_dims,
                             const float *        freq_factors,
                             const mrope_sections sections,
                             cudaStream_t         stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    const int  n_blocks_y = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
    const dim3 grid_dims(nr, n_blocks_y, 1);
    const dim3 thread_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors != nullptr) {
        rope_vision<forward, true, T><<<grid_dims, thread_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
        return;
    }

    rope_vision<forward, false, T><<<grid_dims, thread_dims, 0, stream>>>(
        x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
        attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
498
// Shared host implementation of ROPE and ROPE_BACK, optionally fused with a following
// VIEW + SET_ROWS. forward == false only negates the sine (see rope_yarn).
//
// dst      : the ROPE node. src[0] is the tensor to rotate (F32 or F16), src[1] holds int32
//            positions, and the optional src[2] holds float per-pair frequency factors.
// set_rows : when non-null, results are written into set_rows->data (element type set_rows->type)
//            at the int64 row indices in set_rows->src[1], instead of into dst->data. Only the
//            norm/neox kernels implement this scatter.
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
                            ggml_tensor *               dst,
                            const ggml_tensor *         set_rows = nullptr) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    void *          dst_d           = dst->data;
    const int64_t * row_indices     = nullptr;
    ggml_type       dst_type        = dst->type;
    int             set_rows_stride = 0;

    if (set_rows != nullptr) {
        GGML_ASSERT(forward);
        // The fused kernels address destination rows by row_indices[i2] only and ignore i3.
        GGML_ASSERT(src0->ne[3] == 1);
        dst_d           = set_rows->data;
        row_indices     = (const int64_t *) set_rows->src[1]->data;
        dst_type        = set_rows->type;
        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
    }
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    // dst_type is the element type actually written (the SET_ROWS destination when fused).
    GGML_ASSERT(dst_type   == GGML_TYPE_F32 || dst_type   == GGML_TYPE_F16);
    if (set_rows == nullptr) {
        // When not fused, src0 and dst types must match. The multi/vision kernels use one
        // element type for both source and destination.
        GGML_ASSERT(src0->type == dst_type);
    } else {
        // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and the SET_ROWS destination F16.
        GGML_ASSERT(src0->type == dst_type || (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16));
    }

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    const int64_t ne02 = src0->ne[2]; // number of positions: pos holds ne02 entries per position axis
    const int64_t nr = ggml_nrows(src0);

    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);

    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);

    //const int n_past     = ((int32_t *) dst->op_params)[0];
    const int n_dims     = ((int32_t *) dst->op_params)[1];
    const int mode       = ((int32_t *) dst->op_params)[2];
    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v,  (int32_t *) dst->op_params + 11, sizeof(int)*4);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        GGML_ASSERT(n_dims == ne00/2);
        // rope_vision takes the pair index modulo sections.v[0] + sections.v[1].
        GGML_ASSERT(sections.v[0] + sections.v[1] > 0);
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                  set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                 set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_mrope && !is_vision) {
        // rope_multi has no SET_ROWS scatter: it would ignore row_indices and write with the src0 type.
        GGML_ASSERT(set_rows == nullptr);
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                     corr_dims, freq_factors, sections, is_imrope, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                     corr_dims, freq_factors, sections, is_imrope, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_vision) {
        // rope_vision has no SET_ROWS scatter: it would ignore row_indices and write with the src0 type.
        GGML_ASSERT(set_rows == nullptr);
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                      corr_dims, freq_factors, sections, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                      corr_dims, freq_factors, sections, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                  set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                 set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    }
}
654
// ROPE: forward rotation of dst->src[0], result written to dst.
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr bool is_forward = true;
    ggml_cuda_op_rope_impl<is_forward>(ctx, dst);
}

// ROPE_BACK: the same kernels with the sine negated, i.e. rotation by the opposite angle.
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    constexpr bool is_forward = false;
    ggml_cuda_op_rope_impl<is_forward>(ctx, dst);
}

// Fused ROPE + VIEW + SET_ROWS: the rotated rows of `rope` are scattered straight into `set_rows`.
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
    constexpr bool is_forward = true;
    ggml_cuda_op_rope_impl<is_forward>(ctx, rope, set_rows);
}