summary | refs | log | tree | commit | diff
path: root/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp
diff options
context:
space:
mode:
author: Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
committer: Mitja Felicijan <mitja.felicijan@gmail.com> 2026-02-12 20:57:17 +0100
commit: b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree: 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp
download: llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp')
-rw-r--r-- llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp | 76
1 files changed, 76 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp b/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp
new file mode 100644
index 0000000..845b484
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-sycl/repeat_back.cpp
@@ -0,0 +1,76 @@
+#include "repeat_back.hpp"
+
+#include "common.hpp"
+
+// Checks, via GGML_ASSERT, that each of the four extents ne[0..3] of tensor `t` is strictly
+// below INT_MAX, so that the extents can be stored in `int` variables without overflow.
+// The argument `t` is evaluated four times, so pass an expression without side effects.
+// NOTE(review): <climits> is not included here; INT_MAX is presumably made visible through
+// one of the project headers above - confirm.
+#define GGML_ASSERT_TENSOR_FITS_INT(t) \
+ GGML_ASSERT((t)->ne[0] < INT_MAX && (t)->ne[1] < INT_MAX && (t)->ne[2] < INT_MAX && (t)->ne[3] < INT_MAX)
+
+// SYCL implementation of the "repeat back" reduction.
+//
+// For every element of `dst`, sums the elements of `dst->src[0]` located at
+// (i0 + j0*ne0, i1 + j1*ne1, i2 + j2*ne2, i3 + j3*ne3) for all j_k in [0, nr_k), where
+// ne_k is the extent of `dst` along dimension k and nr_k = src_ne_k / ne_k.
+//
+// Parameters:
+//   ctx - backend context; the kernel is submitted to ctx.stream().
+//   dst - output tensor (F32). Its source tensor dst->src[0] must also be F32.
+//
+// Notes:
+//   * The source is addressed through its byte strides nb[0..3]; the destination is written
+//     by linear element index, i.e. it is treated as densely packed.
+//   * NOTE(review): the source extents are presumably exact multiples of the destination
+//     extents (guaranteed by the caller); any remainder is ignored by the integer division.
+//   * All index and offset arithmetic is done in size_t with integer division. The previous
+//     float-reciprocal decomposition lost precision once the linear index exceeded 2^24 and
+//     the int stride/offset math overflowed for sources larger than 2 GiB.
+void ggml_sycl_op_repeat_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    GGML_ASSERT_TENSOR_FITS_INT(dst);
+    GGML_ASSERT_TENSOR_FITS_INT(src0);
+
+    const char * base   = static_cast<const char *>(src0->data);
+    float *      dst_dd = static_cast<float *>(dst->data);
+
+    // Destination extents.
+    const size_t ne0 = static_cast<size_t>(dst->ne[0]);
+    const size_t ne1 = static_cast<size_t>(dst->ne[1]);
+    const size_t ne2 = static_cast<size_t>(dst->ne[2]);
+    const size_t ne3 = static_cast<size_t>(dst->ne[3]);
+
+    const size_t total = ne0 * ne1 * ne2 * ne3;
+    if (total == 0) {
+        // Nothing to write; also avoids dividing by a zero extent below and launching an
+        // empty nd_range.
+        return;
+    }
+
+    // Number of source tiles folded into each destination element, per dimension.
+    const size_t nr0 = static_cast<size_t>(src0->ne[0]) / ne0;
+    const size_t nr1 = static_cast<size_t>(src0->ne[1]) / ne1;
+    const size_t nr2 = static_cast<size_t>(src0->ne[2]) / ne2;
+    const size_t nr3 = static_cast<size_t>(src0->ne[3]) / ne3;
+
+    // Source strides in bytes, kept at full width.
+    const size_t nb0 = src0->nb[0];
+    const size_t nb1 = src0->nb[1];
+    const size_t nb2 = src0->nb[2];
+    const size_t nb3 = src0->nb[3];
+
+    constexpr size_t BLOCK_SIZE = 256;
+    const size_t     num_blocks = (total + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    queue_ptr stream = ctx.stream();
+
+    stream->parallel_for(
+        sycl::nd_range<1>(sycl::range<1>(num_blocks * BLOCK_SIZE), sycl::range<1>(BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            const size_t i = item_ct1.get_global_linear_id();
+            if (i >= total) {
+                // Padding work-item of the last, partially filled work-group.
+                return;
+            }
+
+            // Exact decomposition of the linear destination index into 4D coordinates.
+            const size_t i0 = i % ne0;
+            const size_t i1 = (i / ne0) % ne1;
+            const size_t i2 = (i / (ne0 * ne1)) % ne2;
+            const size_t i3 = i / (ne0 * ne1 * ne2);
+
+            float acc = 0.0f;
+
+            // j0 varies fastest, matching the accumulation order of the original
+            // carry-chain loop so the floating-point sum is unchanged.
+            for (size_t j3 = 0; j3 < nr3; ++j3) {
+                const size_t off3 = (i3 + j3 * ne3) * nb3;
+                for (size_t j2 = 0; j2 < nr2; ++j2) {
+                    const size_t off2 = off3 + (i2 + j2 * ne2) * nb2;
+                    for (size_t j1 = 0; j1 < nr1; ++j1) {
+                        const size_t off1 = off2 + (i1 + j1 * ne1) * nb1;
+                        for (size_t j0 = 0; j0 < nr0; ++j0) {
+                            const size_t off = off1 + (i0 + j0 * ne0) * nb0;
+                            acc += *reinterpret_cast<const float *>(base + off);
+                        }
+                    }
+                }
+            }
+            dst_dd[i] = acc;
+        });
+}