1#include "roll.hpp"
2#include "common.hpp"
3
4using namespace sycl;
5
// Adds `shift` to index `i` and folds the sum back by one period `n` when it
// reaches or exceeds `n`.
//
// Only a single subtraction is performed (no modulo), so the result lies in
// [0, n) only when 0 <= i + shift < 2 * n.
static inline int wrap_add(int i, int shift, int n) {
    const int sum = i + shift;
    if (sum < n) {
        return sum;
    }
    return sum - n;
}
11
12static void kernel_roll_fused_i0_i1(
13 queue &q,
14 const float *src_d,
15 float *dst_d,
16 int ne0, int ne1, int ne2, int ne3,
17 int sh0, int sh1, int sh2, int sh3)
18{
19 if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
20
21
22 const int stride1 = ne0;
23 const int stride2 = ne0 * ne1;
24 const int stride3 = ne0 * ne1 * ne2;
25
26
27 const int shNe0 = (ne0 - sh0) % ne0;
28 const int shNe1 = (ne1 - sh1) % ne1;
29 const int shNe2 = (ne2 - sh2) % ne2;
30 const int shNe3 = (ne3 - sh3) % ne3;
31
32
33 const size_t g0 = (size_t) ne3;
34 const size_t g1 = (size_t) ne2;
35 const size_t g2 = (size_t) (ne1 * ne0);
36
37 const range<3> global{ g0, g1, g2 };
38
39 q.submit([&](handler &h) {
40 h.parallel_for(global, [=](id<3> idx) {
41 const int i3 = (int) idx[0];
42 const int i2 = (int) idx[1];
43
44 const int fused = (int) idx[2];
45 const int i1 = fused / ne0;
46 const int i0 = fused - i1 * ne0; // fused % ne0
47
48
49 const int idx_dst = i0
50 + i1 * stride1
51 + i2 * stride2
52 + i3 * stride3;
53
54
55 const int s0 = wrap_add(i0, shNe0, ne0);
56 const int s1 = wrap_add(i1, shNe1, ne1);
57 const int s2 = wrap_add(i2, shNe2, ne2);
58 const int s3 = wrap_add(i3, shNe3, ne3);
59
60 const int idx_src = s0
61 + s1 * stride1
62 + s2 * stride2
63 + s3 * stride3;
64
65 dst_d[idx_dst] = src_d[idx_src];
66 });
67 });
68}
69
70void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
71 GGML_ASSERT(dst->type == GGML_TYPE_F32);
72
73 const ggml_tensor *src = dst->src[0];
74 GGML_ASSERT(src && src->type == GGML_TYPE_F32);
75
76 const int ne0 = (int) dst->ne[0];
77 const int ne1 = (int) dst->ne[1];
78 const int ne2 = (int) dst->ne[2];
79 const int ne3 = (int) dst->ne[3];
80
81 const int32_t *params = (const int32_t *) dst->op_params;
82 int shift0 = params[0];
83 int shift1 = params[1];
84 int shift2 = params[2];
85 int shift3 = params[3];
86
87
88 if ((shift0 | shift1 | shift2 | shift3) == 0) {
89 const size_t nb = ggml_nbytes(src);
90 queue *q = ctx.stream();
91 SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
92 return;
93 }
94
95 auto norm = [](int sh, int n) -> int {
96 if (n <= 0) return 0;
97 sh %= n;
98 if (sh < 0) sh += n;
99 return sh;
100 };
101 shift0 = norm(shift0, ne0);
102 shift1 = norm(shift1, ne1);
103 shift2 = norm(shift2, ne2);
104 shift3 = norm(shift3, ne3);
105
106 try {
107 queue *q = ctx.stream();
108
109 const float *src_d = (const float *) src->data;
110 float *dst_d = (float *) dst->data;
111 GGML_ASSERT(src_d && dst_d);
112
113 kernel_roll_fused_i0_i1(
114 *q, src_d, dst_d,
115 ne0, ne1, ne2, ne3,
116 shift0, shift1, shift2, shift3
117 );
118 } catch (const std::exception &e) {
119 std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
120 throw;
121 }
122}