1#include <sycl/sycl.hpp>
2#include "common.hpp"
3#include "add-id.hpp"
4
5static void add_id_kernel(
6 const float* src0,
7 const float* src1,
8 const int32_t* src2,
9 float* dst,
10 int64_t ne0,
11 int64_t ne1,
12 size_t nb01,
13 size_t nb02,
14 size_t nb11,
15 size_t nb21,
16 sycl::nd_item<3> item_ct1) {
17 const int64_t i1 = item_ct1.get_group(2);
18 const int64_t i2 = item_ct1.get_group(1);
19
20 const int i11 =
21 *(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);
22
23 const size_t nb1 = ne0 * sizeof(float);
24 const size_t nb2 = ne1 * nb1;
25
26 float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);
27 const float* src0_row =
28 (const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);
29 const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);
30
31 for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;
32 i0 += item_ct1.get_local_range(2)) {
33 dst_row[i0] = src0_row[i0] + src1_row[i0];
34 }
35}
36
37void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
38 const ggml_tensor* src0 = dst->src[0];
39 const ggml_tensor* src1 = dst->src[1];
40 const ggml_tensor* src2 = dst->src[2];
41
42 GGML_TENSOR_TERNARY_OP_LOCALS
43
44 GGML_ASSERT(dst->type == GGML_TYPE_F32);
45 GGML_ASSERT(src0->type == GGML_TYPE_F32);
46 GGML_ASSERT(src1->type == GGML_TYPE_F32);
47 GGML_ASSERT(src2->type == GGML_TYPE_I32);
48
49 GGML_ASSERT(nb00 == sizeof(float));
50 GGML_ASSERT(nb10 == sizeof(float));
51 GGML_ASSERT(nb20 == sizeof(int32_t));
52
53 const float* src0_d = (const float*)src0->data;
54 const float* src1_d = (const float*)src1->data;
55 const int32_t* src2_d = (const int32_t*)src2->data;
56 float* dst_d = (float*)dst->data;
57
58 int threads = std::min((int)ne00, 768); // cols
59 ctx.stream()->parallel_for(
60 sycl::nd_range<3>(
61 sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
62 sycl::range<3>(1, 1, threads)),
63 [=](sycl::nd_item<3> item_ct1) {
64 add_id_kernel(
65 src0_d,
66 src1_d,
67 src2_d,
68 dst_d,
69 ne0,
70 ne1,
71 nb01,
72 nb02,
73 nb11,
74 nb21,
75 item_ct1);
76 });
77}