1#include "convert.cuh"
2#include "ggml-cuda/common.cuh"
3#include "ggml.h"
4#include "rope.cuh"
5
// YaRN correction-dimension range: v[0] = low cutoff, v[1] = high cutoff,
// filled by ggml_rope_yarn_corr_dims() on the host and passed by value into
// the kernels (used by rope_yarn_ramp).
struct rope_corr_dims {
    float v[2];
};
9
10
// Per-section dimension counts for multi-rope (M-RoPE); section k selects the
// position channel pos[i2 + ne02*k] in rope_multi/rope_vision (presumably the
// time/height/width/extra channels — confirm against the model code).
struct mrope_sections {
    int v[4];
};
14
// YaRN ramp for the dimension-pair index i0/2: returns 1 below `low`,
// 0 above `high`, and falls off linearly in between (the 0.001f floor
// guards against division by zero when low == high).
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float span    = max(0.001f, high - low);
    const float linear  = (i0 / 2 - low) / span;
    const float clamped = min(1.0f, max(0.0f, linear));
    return 1.0f - clamped;
}
19
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
//
// Computes cos/sin of the rotation angle for dimension pair i0: the
// interpolated angle (freq_scale * theta_extrap) is blended with the
// extrapolated one according to the YaRN ramp, and mscale receives the
// interpolation magnitude correction. For the backward pass
// (forward == false) the sine is negated, i.e. the inverse rotation.
template<bool forward>
static __device__ void rope_yarn(
    const float theta_extrap, const float freq_scale, const rope_corr_dims corr_dims, const int64_t i0, const float ext_factor,
    float mscale, float & cos_theta, float & sin_theta) {
    const float theta_interp = freq_scale * theta_extrap;

    float theta = theta_interp;
    if (ext_factor != 0.0f) {
        // n-d rotational scaling corrected for extrapolation
        const float mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
        theta = theta_interp * (1 - mix) + theta_extrap * mix;

        // n-d magnitude scaling corrected for interpolation
        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
    }

    cos_theta = cosf(theta) * mscale;
    sin_theta = sinf(theta) * mscale;
    if (!forward) {
        sin_theta *= -1.0f;
    }
}
42
// RoPE "normal" mode: rotates adjacent element pairs (x[i0], x[i0+1]) within
// the first n_dims elements of each row; elements past n_dims are copied
// through unchanged. Launch layout (see rope_norm_cuda): blockIdx.x/threadIdx.x
// index the destination row, blockIdx.y/threadIdx.y index the element pair.
// T: source element type, D: destination element type (float or half).
// s01/s02/s03 are source strides, s1/s2/s3 destination strides, in elements.
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
                                 D * dst,
                                 const int ne00,
                                 const int ne01,
                                 const int ne02,
                                 const int s01,
                                 const int s02,
                                 const int s03,
                                 const int s1,
                                 const int s2,
                                 const int s3,
                                 const int n_dims,
                                 const int32_t * pos,
                                 const float freq_scale,
                                 const float ext_factor,
                                 const float attn_factor,
                                 const rope_corr_dims corr_dims,
                                 const float theta_scale,
                                 const float * freq_factors,
                                 const int64_t * row_indices,
                                 const int set_rows_stride) {
    // index of the first element of the pair handled by this thread
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // decompose the flat row index into (i1, i2, i3) tensor coordinates
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
    const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
    if (set_rows_stride != 0) {
        idst = i1 * s1 + i0;
        idst += row_indices[i2] * set_rows_stride;
    }

    // store both pair elements with a single 8-byte (float2) or 4-byte (half2)
    // coalesced write
    const auto store_coalesced = [&](float x0, float x1) {
        if constexpr (std::is_same_v<float, D>) {
            float2 v = make_float2(x0, x1);
            ggml_cuda_memcpy_1<8>(dst + idst, &v);
        } else if constexpr (std::is_same_v<half, D>) {
            half2 v = make_half2(x0, x1);
            ggml_cuda_memcpy_1<4>(dst + idst, &v);
        }
    };

    if (i0 >= n_dims) {
        // past the rotated dimensions: pass the elements through unchanged
        store_coalesced(x[ix + 0], x[ix + 1]);
        return;
    }

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    const float x0 = x[ix + 0];
    const float x1 = x[ix + 1];

    store_coalesced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
}
114
// RoPE "NeoX" mode: rotates element pairs that are n_dims/2 apart,
// (x[i0/2], x[i0/2 + n_dims/2]), instead of adjacent pairs; elements past
// n_dims are copied through unchanged.
// Launch layout: blockIdx.x/threadIdx.x index the destination row,
// blockIdx.y/threadIdx.y index the element pair within the row.
// T: source element type, D: destination element type.
// s01/s02/s03 are source strides, s1/s2/s3 destination strides, in elements.
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
                                 D * dst,
                                 const int ne00,
                                 const int ne01,
                                 const int ne02,
                                 const int s01,
                                 const int s02,
                                 const int s03,
                                 const int s1,
                                 const int s2,
                                 const int s3,
                                 const int n_dims,
                                 const int32_t * pos,
                                 const float freq_scale,
                                 const float ext_factor,
                                 const float attn_factor,
                                 const rope_corr_dims corr_dims,
                                 const float theta_scale,
                                 const float * freq_factors,
                                 const int64_t * row_indices,
                                 const int set_rows_stride) {
    // index of the first element of the rotation pair handled by this thread
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // decompose the flat row index into (i1, i2, i3) tensor coordinates
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
    if (set_rows_stride != 0) {
        idst = i1 * s1 + i0 / 2;
        idst += row_indices[i2] * set_rows_stride;
    }

    if (i0 >= n_dims) {
        // past the rotated dimensions: copy the two elements through unchanged
        // (idst/ix already contain i0/2, so the effective offset is i0)
        dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
        dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);

        return;
    }

    const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    // the pair elements are n_dims/2 apart (NeoX layout)
    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0] = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
    dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
181
// Multi-section RoPE (M-RoPE): same pairing as NeoX (elements n_dims/2 apart),
// but the position channel used for theta depends on which section the pair
// index falls into. `pos` holds 4 channels of ne02 entries each
// (pos[i2 + ne02*k] for section k). `is_imrope` selects the interleaved
// section layout (sector taken mod 3) instead of contiguous section ranges.
template <bool forward, bool has_ff, typename T>
static __global__ void rope_multi(const T * x,
                                  T * dst,
                                  const int ne00,
                                  const int ne01,
                                  const int ne02,
                                  const int s01,
                                  const int s02,
                                  const int s03,
                                  const int s1,
                                  const int s2,
                                  const int s3,
                                  const int n_dims,
                                  const int32_t * pos,
                                  const float freq_scale,
                                  const float ext_factor,
                                  const float attn_factor,
                                  const rope_corr_dims corr_dims,
                                  const float theta_scale,
                                  const float * freq_factors,
                                  const mrope_sections sections,
                                  const bool is_imrope) {
    // index of the first element of the rotation pair handled by this thread
    const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);

    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // decompose the flat row index into (i1, i2, i3) tensor coordinates
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    if (i0 >= n_dims) {
        // past the rotated dimensions: copy the two elements through unchanged
        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];

        return;
    }

    // which section this dimension pair belongs to
    const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3];
    const int sec_w = sections.v[1] + sections.v[0];
    const int sector = (i0 / 2) % sect_dims;

    float theta_base = 0.0;
    if (is_imrope) {
        // interleaved layout: sector mod 3 cycles through h/w/t channels,
        // remainder falls through to the 4th ("extra") channel
        if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
        } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
        } else {
            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
        }
    } else {
        // contiguous layout: sections occupy consecutive sector ranges
        if (sector < sections.v[0]) {
            theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
        } else if (sector >= sections.v[0] && sector < sec_w) {
            theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
        } else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
            theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
        } else if (sector >= sec_w + sections.v[2]) {
            theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
        }
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    // the pair elements are n_dims/2 apart (NeoX layout)
    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
266
// Vision RoPE: rotates element pairs that are n_dims apart,
// (x[i0/2], x[i0/2 + n_dims]) — the host code asserts n_dims == ne00/2, so
// this splits the row into two full halves. Only the first two sections are
// used, selecting position channel pos[i2] or pos[i2 + ne02]; the exponent
// `p` is the sector index relative to the start of its section.
template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T * x,
                                   T * dst,
                                   const int ne00,
                                   const int ne01,
                                   const int ne02,
                                   const int s01,
                                   const int s02,
                                   const int s03,
                                   const int s1,
                                   const int s2,
                                   const int s3,
                                   const int n_dims,
                                   const int32_t * pos,
                                   const float freq_scale,
                                   const float ext_factor,
                                   const float attn_factor,
                                   const rope_corr_dims corr_dims,
                                   const float theta_scale,
                                   const float * freq_factors,
                                   const mrope_sections sections) {
    // index of the first element of the rotation pair handled by this thread
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne00) {
        return;
    }

    const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;

    // decompose the flat row index into (i1, i2, i3) tensor coordinates
    const uint32_t i3 = row_dst / (ne01 * ne02);
    const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
    const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;

    int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
    const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;

    // NOTE: sect_dims and sec_w are both v[0]+v[1] here — only two sections
    // are used in vision mode
    const int sect_dims = sections.v[0] + sections.v[1];
    const int sec_w = sections.v[1] + sections.v[0];
    const int sector = (i0 / 2) % sect_dims;

    float theta_base = 0.0;
    if (sector < sections.v[0]) {
        const int p = sector;
        theta_base = pos[i2] * powf(theta_scale, p);
    } else if (sector >= sections.v[0] && sector < sec_w) {
        const int p = sector - sections.v[0];
        theta_base = pos[i2 + ne02] * powf(theta_scale, p);
    }

    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;

    float cos_theta;
    float sin_theta;

    rope_yarn<forward>(theta_base/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, cos_theta, sin_theta);

    // the pair elements are n_dims apart (half-split layout)
    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims];

    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
329
// Host-side launcher for rope_norm: one grid column per destination row (nr
// rows along grid.x); along grid.y each block of CUDA_ROPE_BLOCK_SIZE threads
// covers 2*CUDA_ROPE_BLOCK_SIZE elements of the head dimension (one pair per
// thread). Dispatches on whether frequency factors are present.
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
                           D * dst,
                           const int ne00,
                           const int ne01,
                           const int ne02,
                           const int s01,
                           const int s02,
                           const int s03,
                           const int s1,
                           const int s2,
                           const int s3,
                           const int n_dims,
                           const int nr,
                           const int32_t * pos,
                           const float freq_scale,
                           const float freq_base,
                           const float ext_factor,
                           const float attn_factor,
                           const rope_corr_dims corr_dims,
                           const float * freq_factors,
                           const int64_t * row_indices,
                           const int set_rows_stride,
                           cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    constexpr int elems_per_block = 2*CUDA_ROPE_BLOCK_SIZE; // one pair per thread
    const int n_blocks_y = (ne00 + elems_per_block - 1) / elems_per_block; // ceil-div

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
    } else {
        rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
    }
}
371
// Host-side launcher for rope_neox: one grid column per destination row (nr
// rows along grid.x); along grid.y each block of CUDA_ROPE_BLOCK_SIZE threads
// covers 2*CUDA_ROPE_BLOCK_SIZE elements of the head dimension (one pair per
// thread). Dispatches on whether frequency factors are present.
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
                           D * dst,
                           const int ne00,
                           const int ne01,
                           const int ne02,
                           const int s01,
                           const int s02,
                           const int s03,
                           const int s1,
                           const int s2,
                           const int s3,
                           const int n_dims,
                           const int nr,
                           const int32_t * pos,
                           const float freq_scale,
                           const float freq_base,
                           const float ext_factor,
                           const float attn_factor,
                           const rope_corr_dims corr_dims,
                           const float * freq_factors,
                           const int64_t * row_indices,
                           const int set_rows_stride,
                           cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    constexpr int elems_per_block = 2*CUDA_ROPE_BLOCK_SIZE; // one pair per thread
    const int n_blocks_y = (ne00 + elems_per_block - 1) / elems_per_block; // ceil-div

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
    } else {
        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
    }
}
413
// Host-side launcher for rope_multi (M-RoPE): one grid column per destination
// row (nr rows along grid.x); along grid.y each block of CUDA_ROPE_BLOCK_SIZE
// threads covers 2*CUDA_ROPE_BLOCK_SIZE elements of the head dimension.
// Dispatches on whether frequency factors are present.
template <bool forward, typename T>
static void rope_multi_cuda(const T * x,
                            T * dst,
                            const int ne00,
                            const int ne01,
                            const int ne02,
                            const int s01,
                            const int s02,
                            const int s03,
                            const int s1,
                            const int s2,
                            const int s3,
                            const int n_dims,
                            const int nr,
                            const int32_t * pos,
                            const float freq_scale,
                            const float freq_base,
                            const float ext_factor,
                            const float attn_factor,
                            const rope_corr_dims corr_dims,
                            const float * freq_factors,
                            const mrope_sections sections,
                            const bool is_imrope,
                            cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    constexpr int elems_per_block = 2*CUDA_ROPE_BLOCK_SIZE; // one pair per thread
    const int n_blocks_y = (ne00 + elems_per_block - 1) / elems_per_block; // ceil-div

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f / n_dims);

    if (freq_factors != nullptr) {
        rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    } else {
        rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    }
}
455
// Host-side launcher for rope_vision.
// Breaks down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, y, heads * seq)
// where y ~= ceil(head_dim / (2*CUDA_ROPE_BLOCK_SIZE)) — each thread handles
// one rotation pair. Dispatches on whether frequency factors are present.
template <bool forward, typename T>
static void rope_vision_cuda(const T * x,
                             T * dst,
                             const int ne00,
                             const int ne01,
                             const int ne02,
                             const int s01,
                             const int s02,
                             const int s03,
                             const int s1,
                             const int s2,
                             const int s3,
                             const int n_dims,
                             const int nr,
                             const int32_t * pos,
                             const float freq_scale,
                             const float freq_base,
                             const float ext_factor,
                             const float attn_factor,
                             const rope_corr_dims corr_dims,
                             const float * freq_factors,
                             const mrope_sections sections,
                             cudaStream_t stream) {
    GGML_ASSERT(ne00 % 2 == 0);

    constexpr int elems_per_block = 2*CUDA_ROPE_BLOCK_SIZE; // one pair per thread
    const int n_blocks_y = (ne00 + elems_per_block - 1) / elems_per_block; // ceil-div

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const dim3 block_nums(nr, n_blocks_y, 1);

    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors != nullptr) {
        rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    } else {
        rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections);
    }
}
498
// Common implementation for forward and backward RoPE, optionally fused with
// VIEW + SET_ROWS (set_rows != nullptr, forward only): the rope result is
// written straight into the SET_ROWS destination using its row indices.
// Reads the rope parameters from dst->op_params and dispatches to the
// norm/neox/multi/vision kernel launcher matching the mode and types.
template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
                            ggml_tensor * dst,
                            const ggml_tensor * set_rows = nullptr) {
    const ggml_tensor * src0 = dst->src[0]; // input activations
    const ggml_tensor * src1 = dst->src[1]; // positions
    const ggml_tensor * src2 = dst->src[2]; // optional frequency factors

    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    void * dst_d = dst->data;
    const int64_t * row_indices = nullptr;
    ggml_type dst_type = dst->type;
    int set_rows_stride = 0;

    if (set_rows != nullptr) {
        GGML_ASSERT(forward); // fusion is only supported for the forward pass
        dst_d = set_rows->data;
        row_indices = (const int64_t *) set_rows->src[1]->data;
        dst_type = set_rows->type;
        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
    }
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
    // When not fused, src0 and dst types must match
    // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
    GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));

    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
    const int64_t ne02 = src0->ne[2]; // extent of dim 2 (pos is indexed along it)
    const int64_t nr = ggml_nrows(src0);

    // strides in elements (ggml nb[] are in bytes)
    const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
    const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
    const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);

    const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
    const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
    const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);

    //const int n_past = ((int32_t *) dst->op_params)[0];
    const int n_dims = ((int32_t *) dst->op_params)[1];
    const int mode = ((int32_t *) dst->op_params)[2];
    //const int n_ctx = ((int32_t *) dst->op_params)[3];
    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
    mrope_sections sections;

    // RoPE alteration for extended context
    float freq_base;
    float freq_scale;
    float ext_factor;
    float attn_factor;
    float beta_fast;
    float beta_slow;

    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections.v, (int32_t *) dst->op_params + 11, sizeof(int)*4);

    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;

    if (is_mrope) {
        GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0);
    }

    if (is_vision) {
        GGML_ASSERT(n_dims == ne00/2); // vision mode rotates across row halves
    }

    const int32_t * pos = (const int32_t *) src1_d;

    const float * freq_factors = nullptr;
    if (src2 != nullptr) {
        freq_factors = (const float *) src2->data;
    }

    rope_corr_dims corr_dims;
    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v);

    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                  set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                 set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_mrope && !is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                     corr_dims, freq_factors, sections, is_imrope, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                     s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                     corr_dims, freq_factors, sections, is_imrope, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (is_vision) {
        if (src0->type == GGML_TYPE_F32) {
            rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                      corr_dims, freq_factors, sections, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
                                      s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
                                      corr_dims, freq_factors, sections, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
                                                  s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                  ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                  set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                 s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                 ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                 set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
                                                s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
                                                ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
                                                set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
    }
}
654
// Forward RoPE (rotation applied with positive sine).
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<true>(ctx, dst);
}
658
// Backward RoPE: forward == false negates sin_theta in rope_yarn, applying
// the inverse rotation for the gradient.
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<false>(ctx, dst);
}
662
// Fused ROPE + VIEW + SET_ROWS: the rope result is written directly into the
// set_rows destination (forward only — asserted inside the impl).
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
}