1#include "roll.hpp"
  2#include "common.hpp"
  3
  4using namespace sycl;
  5
  6static inline int wrap_add(int i, int shift, int n) {
  7
  8    int s = i + shift;
  9    return (s >= n) ? (s - n) : s;
 10}
 11
 12static void kernel_roll_fused_i0_i1(
 13    queue &q,
 14    const float *src_d,
 15    float *dst_d,
 16    int ne0, int ne1, int ne2, int ne3,
 17    int sh0, int sh1, int sh2, int sh3)
 18{
 19    if (ne0 == 0 || ne1 == 0 || ne2 == 0 || ne3 == 0) return;
 20
 21
 22    const int stride1 = ne0;
 23    const int stride2 = ne0 * ne1;
 24    const int stride3 = ne0 * ne1 * ne2;
 25
 26
 27    const int shNe0 = (ne0 - sh0) % ne0;
 28    const int shNe1 = (ne1 - sh1) % ne1;
 29    const int shNe2 = (ne2 - sh2) % ne2;
 30    const int shNe3 = (ne3 - sh3) % ne3;
 31
 32
 33    const size_t g0 = (size_t) ne3;
 34    const size_t g1 = (size_t) ne2;
 35    const size_t g2 = (size_t) (ne1 * ne0);
 36
 37    const range<3> global{ g0, g1, g2 };
 38
 39    q.submit([&](handler &h) {
 40        h.parallel_for(global, [=](id<3> idx) {
 41            const int i3 = (int) idx[0];
 42            const int i2 = (int) idx[1];
 43
 44            const int fused = (int) idx[2];
 45            const int i1 = fused / ne0;
 46            const int i0 = fused - i1 * ne0;  // fused % ne0
 47
 48
 49            const int idx_dst = i0
 50                              + i1 * stride1
 51                              + i2 * stride2
 52                              + i3 * stride3;
 53
 54
 55            const int s0 = wrap_add(i0, shNe0, ne0);
 56            const int s1 = wrap_add(i1, shNe1, ne1);
 57            const int s2 = wrap_add(i2, shNe2, ne2);
 58            const int s3 = wrap_add(i3, shNe3, ne3);
 59
 60            const int idx_src = s0
 61                              + s1 * stride1
 62                              + s2 * stride2
 63                              + s3 * stride3;
 64
 65            dst_d[idx_dst] = src_d[idx_src];
 66        });
 67    });
 68}
 69
 70void ggml_sycl_roll(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 71    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 72
 73    const ggml_tensor *src = dst->src[0];
 74    GGML_ASSERT(src && src->type == GGML_TYPE_F32);
 75
 76    const int ne0 = (int) dst->ne[0];
 77    const int ne1 = (int) dst->ne[1];
 78    const int ne2 = (int) dst->ne[2];
 79    const int ne3 = (int) dst->ne[3];
 80
 81    const int32_t *params = (const int32_t *) dst->op_params;
 82    int shift0 = params[0];
 83    int shift1 = params[1];
 84    int shift2 = params[2];
 85    int shift3 = params[3];
 86
 87
 88    if ((shift0 | shift1 | shift2 | shift3) == 0) {
 89        const size_t nb = ggml_nbytes(src);
 90        queue *q = ctx.stream();
 91        SYCL_CHECK(CHECK_TRY_ERROR(q->memcpy(dst->data, src->data, nb)));
 92        return;
 93    }
 94
 95    auto norm = [](int sh, int n) -> int {
 96        if (n <= 0) return 0;
 97        sh %= n;
 98        if (sh < 0) sh += n;
 99        return sh;
100    };
101    shift0 = norm(shift0, ne0);
102    shift1 = norm(shift1, ne1);
103    shift2 = norm(shift2, ne2);
104    shift3 = norm(shift3, ne3);
105
106    try {
107        queue *q = ctx.stream();
108
109        const float *src_d = (const float *) src->data;
110        float *dst_d = (float *) dst->data;
111        GGML_ASSERT(src_d && dst_d);
112
113        kernel_roll_fused_i0_i1(
114            *q, src_d, dst_d,
115            ne0, ne1, ne2, ne3,
116            shift0, shift1, shift2, shift3
117        );
118    } catch (const std::exception &e) {
119        std::fprintf(stderr, "[SYCL-ROLL] ERROR: %s\n", e.what());
120        throw;
121    }
122}