1#include <sycl/sycl.hpp>
 2#include "common.hpp"
 3#include "add-id.hpp"
 4
 5static void add_id_kernel(
 6    const float* src0,
 7    const float* src1,
 8    const int32_t* src2,
 9    float* dst,
10    int64_t ne0,
11    int64_t ne1,
12    size_t nb01,
13    size_t nb02,
14    size_t nb11,
15    size_t nb21,
16    sycl::nd_item<3> item_ct1) {
17  const int64_t i1 = item_ct1.get_group(2);
18  const int64_t i2 = item_ct1.get_group(1);
19
20  const int i11 =
21      *(const int32_t*)((const char*)src2 + i1 * sizeof(int32_t) + i2 * nb21);
22
23  const size_t nb1 = ne0 * sizeof(float);
24  const size_t nb2 = ne1 * nb1;
25
26  float* dst_row = (float*)((char*)dst + i1 * nb1 + i2 * nb2);
27  const float* src0_row =
28      (const float*)((const char*)src0 + i1 * nb01 + i2 * nb02);
29  const float* src1_row = (const float*)((const char*)src1 + i11 * nb11);
30
31  for (int64_t i0 = item_ct1.get_local_id(2); i0 < ne0;
32       i0 += item_ct1.get_local_range(2)) {
33    dst_row[i0] = src0_row[i0] + src1_row[i0];
34  }
35}
36
37void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
38  const ggml_tensor* src0 = dst->src[0];
39  const ggml_tensor* src1 = dst->src[1];
40  const ggml_tensor* src2 = dst->src[2];
41
42  GGML_TENSOR_TERNARY_OP_LOCALS
43
44  GGML_ASSERT(dst->type == GGML_TYPE_F32);
45  GGML_ASSERT(src0->type == GGML_TYPE_F32);
46  GGML_ASSERT(src1->type == GGML_TYPE_F32);
47  GGML_ASSERT(src2->type == GGML_TYPE_I32);
48
49  GGML_ASSERT(nb00 == sizeof(float));
50  GGML_ASSERT(nb10 == sizeof(float));
51  GGML_ASSERT(nb20 == sizeof(int32_t));
52
53  const float* src0_d = (const float*)src0->data;
54  const float* src1_d = (const float*)src1->data;
55  const int32_t* src2_d = (const int32_t*)src2->data;
56  float* dst_d = (float*)dst->data;
57
58  int threads = std::min((int)ne00, 768);  // cols
59  ctx.stream()->parallel_for(
60      sycl::nd_range<3>(
61          sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads),
62          sycl::range<3>(1, 1, threads)),
63      [=](sycl::nd_item<3> item_ct1) {
64        add_id_kernel(
65            src0_d,
66            src1_d,
67            src2_d,
68            dst_d,
69            ne0,
70            ne1,
71            nb01,
72            nb02,
73            nb11,
74            nb21,
75            item_ct1);
76      });
77}