#include "common.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml.h"
#include "element_wise.hpp"

#include <climits>
   5
   6#define SYCL_GLOBAL_ID_LOOP(K, ITEM) \
   7    for (auto i = ITEM.get_global_id(0); i < (size_t)K; i += ITEM.get_global_range(0))
   8
   9#define SYCL_LOCAL_ID_CALC(ITEM, IDX) \
  10    (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX))
  11
  12
  13static void acc_f32(const float * x, const float * y, float * dst, const int ne,
  14    const int ne10, const int ne11, const int ne12,
  15    const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) {
  16    const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0);
  17    if (i >= ne) {
  18        return;
  19    }
  20    int src1_idx = i - offset;
  21    int oz = src1_idx / nb2;
  22    int oy = (src1_idx - (oz * nb2)) / nb1;
  23    int ox = src1_idx % nb1;
  24    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
  25        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
  26    } else {
  27        dst[i] = x[i];
  28    }
  29}
  30
  31/* Unary OP funcs */
  32template<typename T>
  33static __dpct_inline__ T op_sgn(T x) {
  34    return x > static_cast<T>(0.f) ? static_cast<T>(1.f) : ((x < static_cast<T>(0.f) ? static_cast<T>(-1.f) : static_cast<T>(0.f)));
  35}
  36
  37template<typename T>
  38static __dpct_inline__ T op_abs(T x) {
  39    return sycl::fabs(x);
  40}
  41
  42template<typename T>
  43static __dpct_inline__ T op_elu(T x) {
  44    return (x > static_cast<T>(0.f)) ? x : sycl::expm1(x);
  45}
  46
  47template<typename T>
  48static __dpct_inline__ T op_gelu(T x) {
  49    const T GELU_COEF_A    = static_cast<T>(0.044715f);
  50    const T SQRT_2_OVER_PI = static_cast<T>(0.79788456080286535587989211986876f);
  51    return static_cast<T>(0.5f) * x *
  52           (static_cast<T>(1.0f) +
  53            sycl::tanh(SQRT_2_OVER_PI * x * (static_cast<T>(1.0f) + GELU_COEF_A * x * x)));
  54}
  55
  56template<typename T>
  57static __dpct_inline__ T op_silu(T x) {
  58    return x / (static_cast<T>(1.0f) + sycl::native::exp(-x));
  59}
  60
  61template<typename T>
  62static __dpct_inline__ T op_gelu_quick(T x) {
  63    const T GELU_QUICK_COEF_LOCAL = static_cast<T>(-1.702f);
  64    return x * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x)));
  65}
  66
  67template<typename T>
  68static __dpct_inline__ T op_gelu_erf(T x) {
  69    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
  70    return static_cast<T>(0.5f) * x * (static_cast<T>(1.0f) + sycl::erf(x * SQRT_2_INV));
  71}
  72
  73template<typename T>
  74static __dpct_inline__ T op_tanh(T x) {
  75    return sycl::tanh(x);
  76}
  77
  78template<typename T>
  79static __dpct_inline__ T op_relu(T x) {
  80    return sycl::fmax(x, static_cast<T>(0));
  81}
  82
  83template<typename T>
  84static __dpct_inline__ T op_sigmoid(T x) {
  85    return static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(-x));
  86}
  87
  88template<typename T>
  89static __dpct_inline__ T op_sqrt(T x) {
  90    return sycl::sqrt(x);
  91}
  92
  93template<typename T>
  94static __dpct_inline__ T op_sin(T x) {
  95    return sycl::sin(x);
  96}
  97
  98template<typename T>
  99static __dpct_inline__ T op_cos(T x) {
 100    return sycl::cos(x);
 101}
 102
 103template<typename T>
 104static __dpct_inline__ T op_hardsigmoid(T x) {
 105    return sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 106}
 107
 108template<typename T>
 109static __dpct_inline__ T op_hardswish(T x) {
 110    return x * sycl::fmin(static_cast<T>(1.0f), sycl::fmax(static_cast<T>(0.0f), (x + static_cast<T>(3.0f)) / static_cast<T>(6.0f)));
 111}
 112
 113template<typename T>
 114static __dpct_inline__ T op_exp(T x) {
 115    return sycl::exp(x);
 116}
 117
 118template<typename T>
 119static __dpct_inline__ T op_log(T x) {
 120    if (x <= static_cast<T>(0)) {
 121        return neg_infinity<T>();
 122    }
 123    return sycl::log(x);
 124}
 125
 126template<typename T>
 127static __dpct_inline__ T op_softplus(T x) {
 128    const float xf = (float) x;
 129    const float ax = sycl::fabs(xf);
 130    const float m  = sycl::fmax(xf, 0.0f);
 131    const float y  = m + sycl::log1p(sycl::exp(-ax));
 132    return (T) y;
 133}
 134
 135template<typename T>
 136static __dpct_inline__ T op_neg(T x) {
 137    return -x;
 138}
 139
 140template<typename T>
 141static __dpct_inline__ T op_step(T x) {
 142    return (x > static_cast<T>(0.0f)) ? static_cast<T>(1.0f) : static_cast<T>(0.0f);
 143}
 144
 145template<typename T>
 146static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) {
 147    T neg_slope_T = static_cast<T>(negative_slope);
 148    return sycl::fmax(x, static_cast<T>(0)) +
 149           sycl::fmin(x, static_cast<T>(0.0f)) * neg_slope_T;
 150}
 151
 152template<typename T>
 153static __dpct_inline__ T op_sqr(T x) {
 154    return x * x;
 155}
 156
 157template<typename T>
 158static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) {
 159    return x < static_cast<T>(min_val) ? static_cast<T>(min_val) : (x > static_cast<T>(max_val) ? static_cast<T>(max_val) : x);
 160}
 161
 162template<typename T>
 163static __dpct_inline__ T op_floor(T x) {
 164    return sycl::floor(x);
 165}
 166
 167template<typename T>
 168static __dpct_inline__ T op_ceil(T x) {
 169    return sycl::ceil(x);
 170}
 171
 172template<typename T>
 173static __dpct_inline__ T op_round(T x) {
 174    return sycl::round(x);
 175}
 176
 177template<typename T>
 178static __dpct_inline__ T op_trunc(T x) {
 179    return sycl::trunc(x);
 180}
 181
 182template<typename T, typename F>
 183static void unary_op_generic_kernel(
 184        const T * x,
 185        T * dst,
 186        const int k,
 187        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3,
 188        const size_t nb0,  const size_t nb1,  const size_t nb2,  const size_t nb3,
 189        const size_t nbd0, const size_t nbd1, const size_t nbd2, const size_t nbd3,
 190        const sycl::nd_item<1> & item_ct1,
 191        F func) {
 192
 193        (void) ne3;
 194    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 195        const int64_t i0 =  i % ne0;
 196        const int64_t i1 = (i / ne0)        % ne1;
 197        const int64_t i2 = (i / (ne0*ne1))  % ne2;
 198        const int64_t i3 =  i / (ne0*ne1*ne2);
 199
 200        const char * src_base = (const char *) x;
 201        char       * dst_base = (char *) dst;
 202
 203        const T * srcp = (const T *)(src_base + i0*nb0  + i1*nb1  + i2*nb2  + i3*nb3 );
 204        T *       dstp = (T *)(dst_base + i0*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3);
 205
 206        *dstp = func(*srcp);
 207    }
 208}
 209
 210template<typename T>
 211static void unary_op_sqrt_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 212    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 213        dst[i] = op_sqrt(x[i]);
 214    }
 215}
 216
 217template<typename T>
 218static void unary_op_sin_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 219    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 220        dst[i] = op_sin(x[i]);
 221    }
 222}
 223
 224template<typename T>
 225static void unary_op_cos_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 226    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 227        dst[i] = op_cos(x[i]);
 228    }
 229}
 230
 231template<typename T>
 232static void unary_op_log_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 233    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 234        dst[i] = op_log(x[i]);
 235    }
 236}
 237
 238
 239template<typename T>
 240static void unary_op_leaky_relu_kernel(const T * x, T * dst, const int k, float negative_slope, const sycl::nd_item<1> &item_ct1) {
 241    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 242        dst[i] = op_leaky_relu(x[i], negative_slope);
 243    }
 244}
 245
 246template<typename T>
 247static void unary_op_sqr_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 248    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 249        dst[i] = op_sqr(x[i]);
 250    }
 251}
 252
 253template<typename T>
 254static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1, float min_val, float max_val) {
 255    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 256        dst[i] = op_clamp(x[i], min_val, max_val);
 257    }
 258}
 259
 260template<typename T>
 261static void unary_op_floor_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 262    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 263        dst[i] = op_floor(x[i]);
 264    }
 265}
 266
 267template<typename T>
 268static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 269    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 270        dst[i] = op_ceil(x[i]);
 271    }
 272}
 273
 274template<typename T>
 275static void unary_op_round_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 276    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 277        dst[i] = op_round(x[i]);
 278    }
 279}
 280
 281template<typename T>
 282static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
 283    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 284        dst[i] = op_trunc(x[i]);
 285    }
 286}
 287
 288template<typename  T>
 289static void upscale(const T  *x, T *dst, const int nb00, const int nb01,
 290                        const int nb02, const int nb03, const int ne10, const int ne11,
 291                        const int ne12, const int ne13, const float sf0, const float sf1,
 292                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
 293    int index = item_ct1.get_local_id(0) +
 294               item_ct1.get_group(0) * item_ct1.get_local_range(0);
 295    if (index >= ne10 * ne11 * ne12 * ne13) {
 296        return;
 297    }
 298    // operation
 299    int i10 = index % ne10;
 300    int i11 = (index / ne10) % ne11;
 301    int i12 = (index / (ne10 * ne11)) % ne12;
 302    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
 303
 304    int i00 = static_cast<int>(i10 / sf0);
 305    int i01 = static_cast<int>(i11 / sf1);
 306    int i02 = static_cast<int>(i12 / sf2);
 307    int i03 = static_cast<int>(i13 / sf3);
 308
 309    dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 310}
 311
 312template<typename T>
 313static void clamp(const T * x, T * dst, const float min, const float max, const int k,
 314                      const sycl::nd_item<1> &item_ct1) {
 315    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 316        dst[i] = x[i] < static_cast<T>(min) ? static_cast<T>(min) : (x[i] > static_cast<T>(max) ? static_cast<T>(max) : x[i]);
 317    }
 318}
 319
 320template<typename T>
 321static void gated_op_fused_geglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
 322    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 323        const int64_t j0 = (i / n) * o0 + (i % n);
 324        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 325        dst[i] = op_gelu(x[j0]) * g[j1];
 326    }
 327}
 328
 329template<typename T>
 330static void gated_op_fused_reglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
 331    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 332        const int64_t j0 = (i / n) * o0 + (i % n);
 333        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 334        dst[i] = op_relu(x[j0]) * g[j1];
 335    }
 336}
 337
 338template<typename T>
 339static void gated_op_fused_swiglu(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
 340    SYCL_GLOBAL_ID_LOOP(k, item_ct1)  {
 341        const int64_t j0 = (i / n) * o0 + (i % n);
 342        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 343        dst[i] = op_silu(x[j0]) * g[j1];
 344    }
 345}
 346
 347template<typename T>
 348static void gated_op_fused_geglu_erf(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
 349    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 350        const int64_t j0 = (i / n) * o0 + (i % n);
 351        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 352        dst[i] = op_gelu_erf(x[j0]) * g[j1];
 353    }
 354}
 355
 356template<typename T>
 357static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const uint64_t k, const uint64_t n, const uint64_t o0, const uint64_t o1, const sycl::nd_item<1> &item_ct1) {
 358    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 359        const int64_t j0 = (i / n) * o0 + (i % n);
 360        const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 361        dst[i] = op_gelu_quick(x[j0]) * g[j1];
 362    }
 363}
 364
 365namespace ggml_sycl_detail {
 366static void acc_f32_sycl(const float *x, const float *y, float *dst,
 367                         const int n_elements, const int ne10, const int ne11,
 368                         const int ne12, const int nb1, const int nb2,
 369                         const int offset, queue_ptr stream) {
 370    int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE);
 371    stream->parallel_for(
 372        sycl::nd_range<1>(sycl::range<1>(num_blocks) *
 373                              sycl::range<1>(SYCL_ACC_BLOCK_SIZE),
 374                          sycl::range<1>(SYCL_ACC_BLOCK_SIZE)),
 375        [=](sycl::nd_item<1> item_ct1) {
 376            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
 377                    item_ct1);
 378        });
 379}
 380
 381template<typename T>
 382static void arange_kernel(T * dst, const int k, T start, T step,
 383                         const sycl::nd_item<1> &item_ct1) {
 384    SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
 385        dst[i] = start + static_cast<T>(i) * step;
 386    }
 387}
 388
 389template<typename T>
 390static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01,
 391                             const int nb02, const int nb03, const int ne10, const int ne11,
 392                             const int ne12, const int ne13, const float sf0, const float sf1,
 393                             const float sf2, const float sf3, queue_ptr stream) {
 394    int dst_size = ne10 * ne11 * ne12 * ne13;
 395    int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE);
 396    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
 397    stream->parallel_for(
 398        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
 399            upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
 400        });
 401}
 402
 403template<typename KernelInvoker, typename... Args>
 404static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 405#if defined (GGML_SYCL_F16)
 406    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
 407    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 408#else
 409    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
 410    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 411#endif
 412    GGML_ASSERT(dst->src[0]->type == dst->type);
 413    dpct::queue_ptr main_stream = ctx.stream();
 414    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 415    switch (dst->type) {
 416#if defined (GGML_SYCL_F16)
 417        case GGML_TYPE_F16:
 418            {
 419                auto data_pts = cast_data<sycl::half>(dst);
 420                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
 421                break;
 422            }
 423#endif
 424        case GGML_TYPE_F32:
 425            {
 426                auto data_pts = cast_data<float>(dst);
 427                kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...);
 428                break;
 429            }
 430        default:
 431            GGML_ABORT("GGML tensor type not supported!\n");
 432    }
 433}
 434
 435template<typename KernelInvoker, typename... Args>
 436static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 437#if defined (GGML_SYCL_F16)
 438    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
 439    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 440#else
 441    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
 442    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 443#endif
 444    GGML_ASSERT(dst->src[0]->type == dst->type);
 445    dpct::queue_ptr main_stream = ctx.stream();
 446    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 447    const ggml_tensor * src0 = dst->src[0];
 448    const ggml_tensor * src1 = dst->src[1];
 449    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;;
 450    GGML_ASSERT(dst->ne[0] == nc);
 451    GGML_ASSERT(ggml_is_contiguous_1(dst->src[0]));
 452    GGML_ASSERT(ggml_is_contiguous(dst));
 453    const int32_t swapped = ((const int32_t *) dst->op_params)[1];
 454    void * src0_d = src0->data;
 455    void * src1_d = src1 ? src1->data : src0->data;
 456    const int64_t src0_o = src0->nb[1];
 457    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 458    void * dst_d = dst->data;
 459    if (src1) {
 460        GGML_ASSERT(ggml_is_contiguous_1(src1));
 461        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
 462        GGML_ASSERT(src1->ne[0] == nc);
 463        GGML_ASSERT(src0->type == src1->type);
 464    }
 465    switch (dst->type) {
 466#if defined (GGML_SYCL_F16)
 467        case GGML_TYPE_F16:
 468            {
 469                sycl::half * src0_p = (sycl::half *) src0_d;
 470                sycl::half * src1_p = (sycl::half *) src1_d;
 471
 472                    if (!src1) {
 473                        src0_p += swapped ? nc : 0;
 474                        src1_p += swapped ? 0 : nc;
 475                    }
 476                kernel_invoker(src0_p,
 477                               src1_p,
 478                               (sycl::half *) dst_d,
 479                               ggml_nelements(dst),
 480                               nc,
 481                               src0_o / sizeof(sycl::half),
 482                               src1_o / sizeof(sycl::half),
 483                               main_stream,
 484                               std::forward<Args>(args)...);
 485                break;
 486            }
 487#endif
 488        case GGML_TYPE_F32:
 489            {
 490                float * src0_p = (float *) src0_d;
 491                float * src1_p = (float *) src1_d;
 492
 493                    if (!src1) {
 494                        src0_p += swapped ? nc : 0;
 495                        src1_p += swapped ? 0 : nc;
 496                    }
 497
 498                kernel_invoker(src0_p,
 499                               src1_p,
 500                               (float *) dst_d,
 501                               ggml_nelements(dst),
 502                               nc,
 503                               src0_o / sizeof(float),
 504                               src1_o / sizeof(float),
 505                               main_stream,
 506                               std::forward<Args>(args)...);
 507                break;
 508            }
 509        default:
 510            GGML_ABORT("GGML tensor type not supported!\n");
 511    }
 512}
 513
 514template<typename KernelInvoker, typename... Args>
 515static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) {
 516#if defined (GGML_SYCL_F16)
 517    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
 518    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 519#else
 520    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
 521    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 522#endif
 523    GGML_ASSERT(dst->src[0]->type == dst->type);
 524
 525    dpct::queue_ptr main_stream = ctx.stream();
 526    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 527
 528    const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0];
 529    const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1];
 530    const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2];
 531    const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3];
 532    switch (dst->type) {
 533#if defined (GGML_SYCL_F16)
 534        case GGML_TYPE_F16:
 535            {
 536                auto data_pts = cast_data<sycl::half>(dst);
 537                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
 538                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
 539                               main_stream, std::forward<Args>(args)...);
 540                break;
 541            }
 542#endif
 543        case GGML_TYPE_F32:
 544            {
 545                auto data_pts = cast_data<float>(dst);
 546                kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2],
 547                               (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3,
 548                               main_stream, std::forward<Args>(args)...);
 549                break;
 550            }
 551        default:
 552            GGML_ABORT("GGML tensor type not supported!\n");
 553    }
 554}
 555
 556template<typename F>
 557static inline void ggml_sycl_op_unary(
 558        ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) {
 559
 560    ggml_tensor * src0 = dst->src[0];
 561
 562    const int64_t ne0  = dst->ne[0];
 563    const int64_t ne1  = dst->ne[1];
 564    const int64_t ne2  = dst->ne[2];
 565    const int64_t ne3  = dst->ne[3];
 566
 567    const size_t  nb0  = src0->nb[0];
 568    const size_t  nb1  = src0->nb[1];
 569    const size_t  nb2  = src0->nb[2];
 570    const size_t  nb3  = src0->nb[3];
 571
 572    const size_t  nbd0 = dst->nb[0];
 573    const size_t  nbd1 = dst->nb[1];
 574    const size_t  nbd2 = dst->nb[2];
 575    const size_t  nbd3 = dst->nb[3];
 576
 577    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 578        [=](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 579
 580            const int num_blocks = ceil_div(k_elements, 256);
 581
 582            stream->parallel_for(
 583                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
 584                                  sycl::range<1>(256)),
 585                [=](sycl::nd_item<1> item_ct1) {
 586                    unary_op_generic_kernel(
 587                        src, dst_ptr, k_elements,
 588                        ne0, ne1, ne2, ne3,
 589                        nb0, nb1, nb2, nb3,
 590                        nbd0, nbd1, nbd2, nbd3,
 591                        item_ct1,
 592                        func
 593                    );
 594                });
 595        });
 596}
 597
 598
 599static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 600    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 601    float start, stop, step;
 602    memcpy(&start, dst->op_params, sizeof(float));
 603    memcpy(&stop, (float *) dst->op_params + 1, sizeof(float));
 604    memcpy(&step, (float *) dst->op_params + 2, sizeof(float));
 605    dpct::queue_ptr stream = ctx.stream();
 606    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 607    float * dst_ptr = (float *)dst->data;
 608    const int k = (int)ggml_nelements(dst);
 609    const int num_blocks = ceil_div(k, SYCL_ARANGE_BLOCK_SIZE);
 610    stream->parallel_for(
 611        sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE),
 612                          sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)),
 613        [=](sycl::nd_item<1> item_ct1) {
 614            arange_kernel(dst_ptr, k, start, step, item_ct1);
 615        });
 616}
 617
 618} // namespace ggml_sycl_detail
 619
 620
 621
 622static inline void ggml_sycl_op_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 623    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 624        return op_sgn(x);
 625    });
 626}
 627
 628
 629static inline void ggml_sycl_op_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 630    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 631        return op_abs(x);
 632    });
 633}
 634
 635static inline void ggml_sycl_op_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 636    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 637        return op_elu(x);
 638    });
 639}
 640static inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 641    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 642        return op_silu(x);
 643    });
 644}
 645
 646static inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 647    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 648        return op_gelu(x);
 649    });
 650}
 651
 652static inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 653    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 654        return op_gelu_quick(x);
 655    });
 656}
 657
 658static inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 659    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 660        return op_gelu_erf(x);
 661    });
 662}
 663
 664static inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 665    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 666        return op_tanh(x);
 667    });
 668}
 669
 670static inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 671    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 672        return op_relu(x);
 673    });
 674}
 675
 676static inline void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 677    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 678        return op_hardsigmoid(x);
 679    });
 680}
 681
 682static inline void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 683    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 684        return op_hardswish(x);
 685    });
 686}
 687
 688static inline void ggml_sycl_op_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 689    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 690        return op_exp(x);
 691    });
 692}
 693
 694static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 695    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 696        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 697            const int num_blocks = ceil_div(k_elements, SYCL_EXP_BLOCK_SIZE); // Using EXP block size
 698            stream->parallel_for(
 699                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE),
 700                                  sycl::range<1>(SYCL_EXP_BLOCK_SIZE)),
 701                [=](sycl::nd_item<1> item_ct1) {
 702                    unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1);
 703                });
 704        });
 705}
 706
 707static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 708    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 709        return op_softplus(x);
 710    });
 711}
 712
 713static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 714    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 715        return op_neg(x);
 716    });
 717}
 718
 719
 720static inline void ggml_sycl_op_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 721    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 722        return op_step(x);
 723    });
 724}
 725
 726static inline void ggml_sycl_op_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 727    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 728        return op_sigmoid(x);
 729    });
 730}
 731
 732static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 733    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 734        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 735            const int num_blocks = ceil_div(k_elements, SYCL_SQRT_BLOCK_SIZE);
 736            stream->parallel_for(
 737                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE),
 738                                  sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)),
 739                [=](sycl::nd_item<1> item_ct1) {
 740                    unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1);
 741                });
 742        });
 743}
 744
 745static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 746    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 747        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 748            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE);
 749            stream->parallel_for(
 750                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
 751                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
 752                [=](sycl::nd_item<1> item_ct1) {
 753                    unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1);
 754                });
 755        });
 756}
 757
 758static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 759    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 760        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 761            const int num_blocks = ceil_div(k_elements, SYCL_SIN_BLOCK_SIZE); // Using SIN block size
 762            stream->parallel_for(
 763                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE),
 764                                  sycl::range<1>(SYCL_SIN_BLOCK_SIZE)),
 765                [=](sycl::nd_item<1> item_ct1) {
 766                    unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1);
 767                });
 768        });
 769}
 770
 771static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 772    float negative_slope;
 773    memcpy(&negative_slope, dst->op_params, sizeof(float));
 774    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 775        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float slope) {
 776            const int num_blocks = ceil_div(k_elements, SYCL_RELU_BLOCK_SIZE);
 777            stream->parallel_for(
 778                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE),
 779                                  sycl::range<1>(SYCL_RELU_BLOCK_SIZE)),
 780                [=](sycl::nd_item<1> item_ct1) {
 781                    unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1);
 782                });
 783        }, negative_slope);
 784}
 785
 786static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 787    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 788        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 789            const int num_blocks = ceil_div(k_elements, SYCL_SQR_BLOCK_SIZE);
 790            stream->parallel_for(
 791                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE),
 792                                  sycl::range<1>(SYCL_SQR_BLOCK_SIZE)),
 793                [=](sycl::nd_item<1> item_ct1) {
 794                    unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1);
 795                });
 796        });
 797}
 798
 799static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 800    ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst,
 801        [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03,
 802           int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3,
 803           queue_ptr stream) {
 804            ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream);
 805        });
 806}
 807
 808static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 809    float min_val;
 810    float max_val;
 811    memcpy(&min_val, dst->op_params, sizeof(float));
 812    memcpy(&max_val, (float *) dst->op_params + 1, sizeof(float));
 813    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 814        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream, float min_arg, float max_arg) {
 815            const int num_blocks = ceil_div(k_elements, SYCL_CLAMP_BLOCK_SIZE);
 816            stream->parallel_for(
 817                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE),
 818                                  sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)),
 819                [=](sycl::nd_item<1> item_ct1) {
 820                    clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1);
 821                });
 822        }, min_val, max_val);
 823}
 824
 825static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 826    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 827        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 828            const int num_blocks = ceil_div(k_elements, 256);
 829            stream->parallel_for(
 830                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
 831                                  sycl::range<1>(256)),
 832                [=](sycl::nd_item<1> item_ct1) {
 833                    unary_op_floor_kernel(src, dst_ptr, k_elements, item_ct1);
 834                });
 835        });
 836}
 837
 838static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 839    ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
 840        return op_ceil(x);
 841    });
 842}
 843
 844static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 845    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 846        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 847            const int num_blocks = ceil_div(k_elements, 256);
 848            stream->parallel_for(
 849                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
 850                                  sycl::range<1>(256)),
 851                [=](sycl::nd_item<1> item_ct1) {
 852                    unary_op_round_kernel(src, dst_ptr, k_elements, item_ct1);
 853                });
 854        });
 855}
 856
 857static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 858    ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
 859        [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
 860            const int num_blocks = ceil_div(k_elements, 256);
 861            stream->parallel_for(
 862                sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
 863                                  sycl::range<1>(256)),
 864                [=](sycl::nd_item<1> item_ct1) {
 865                    unary_op_trunc_kernel(src, dst_ptr, k_elements, item_ct1);
 866                });
 867        });
 868}
 869
 870static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
 871    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
 872    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
 873    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 874    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
 875    dpct::queue_ptr main_stream = ctx.stream();
 876    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
 877    const float * src0_dd = static_cast<const float *>(dst->src[0]->data);
 878    const float * src1_dd = static_cast<const float*>(dst->src[1]->data);
 879    float *       dst_dd  = static_cast<float *>(dst->data);
 880
 881    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
 882    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
 883    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
 884    int offset = dst->op_params[3] / 4; // offset in bytes
 885
 886    ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream);
 887}
 888
 889static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 890    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
 891        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
 892            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
 893            main_stream->parallel_for(
 894                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
 895                gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
 896            });
 897        });
 898}
 899
 900static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 901    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
 902        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
 903            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu
 904            main_stream->parallel_for(
 905                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
 906                gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
 907            });
 908        });
 909}
 910
 911static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 912    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
 913        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
 914            const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu
 915            main_stream->parallel_for(
 916                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
 917                gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
 918            });
 919        });
 920}
 921
 922__dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
 923    x = sycl::fmin(x, limit);
 924    g = sycl::fmax(sycl::fmin(g, limit), -limit);
 925
 926    float out_glu = x / (1.0f + sycl::native::exp(-x * alpha));
 927    out_glu = out_glu * (1.0f + g);
 928    return out_glu;
 929}
 930
 931
 932template <typename T>
 933static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k,
 934                              const int64_t n, const int64_t o0, const int64_t o1,
 935                              float alpha, float limit, sycl::nd_item<3> item_ct1) {
 936    const int64_t i = int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2);
 937
 938    if (i >= k) {
 939        return;
 940    }
 941
 942    const int64_t j0 = (i / n) * o0 + (i % n);
 943    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
 944
 945    float xi = x[j0];
 946    float gi = g[j1];
 947
 948    dst[i] = ggml_sycl_op_swiglu_oai_single(xi, gi, alpha, limit);
 949}
 950
 951template <typename T>
 952static void swiglu_oai_sycl(const T *       x,
 953                            const T *       g,
 954                            T *             dst,
 955                            const int64_t   k,
 956                            const int64_t   n,
 957                            const int64_t   o0,
 958                            const int64_t   o1,
 959                            const float     alpha,
 960                            const float     limit,
 961                            dpct::queue_ptr stream) {
 962    const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE;
 963    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE),
 964                                           sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)),
 965                         [=](sycl::nd_item<3> item_ct1) {
 966                             swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1);
 967                         });
 968}
 969
 970void ggml_sycl_op_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 971    const ggml_tensor * src0 = dst->src[0];
 972    const ggml_tensor * src1 = dst->src[1];
 973    void * src0_d = src0->data;
 974    void * src1_d = src1 ? src1->data : src0->data;
 975    const int64_t src0_o = src0->nb[1];
 976    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 977    void * dst_d = dst->data;
 978    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 979    dpct::queue_ptr     stream = ctx.stream();
 980
 981    GGML_ASSERT(ggml_is_contiguous_1(src0));
 982    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
 983    GGML_ASSERT(ggml_is_contiguous(dst));
 984
 985    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 986    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 987    GGML_ASSERT(src0->type == dst->type);
 988    GGML_ASSERT(dst->ne[0] == nc);
 989    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
 990
 991    if (src1) {
 992        GGML_ASSERT(ggml_is_contiguous_1(src1));
 993        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
 994        GGML_ASSERT(src1->ne[0] == nc);
 995        GGML_ASSERT(src0->type == src1->type);
 996    }
 997
 998    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
 999    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
1000    const float alpha = ggml_get_op_params_f32(dst, 2);
1001    const float limit = ggml_get_op_params_f32(dst, 3);
1002
1003    float * src0_p = (float *) src0_d;
1004    float * src1_p = (float *) src1_d;
1005
1006    if (!src1) {
1007        src0_p += swapped ? nc : 0;
1008        src1_p += swapped ? 0 : nc;
1009    }
1010
1011    swiglu_oai_sycl(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
1012}
1013
1014static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1015    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1016        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1017            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1018            main_stream->parallel_for(
1019                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1020                gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1021            });
1022        });
1023}
1024
1025static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1026    ggml_sycl_detail::dispatch_ggml_sycl_op_fused_glu(ctx, dst,
1027        [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) {
1028            const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
1029            main_stream->parallel_for(
1030                    sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) {
1031                gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1);
1032            });
1033        });
1034}
1035
1036
1037void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1038    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1039    ggml_sycl_op_sqrt(ctx, dst);
1040}
1041
1042void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1043    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1044    ggml_sycl_op_sin(ctx, dst);
1045}
1046
1047void ggml_sycl_cos(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1048    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1049    ggml_sycl_op_cos(ctx, dst);
1050}
1051
1052void ggml_sycl_acc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1053    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
1054    ggml_sycl_op_acc(ctx, dst);
1055}
1056
1057void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1058    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1059    ggml_sycl_op_gelu(ctx, dst);
1060}
1061
1062void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1063    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1064    ggml_sycl_op_silu(ctx, dst);
1065}
1066
1067void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1068    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1069    ggml_sycl_op_gelu_quick(ctx, dst);
1070}
1071
1072void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1073    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1074    ggml_sycl_op_gelu_erf(ctx, dst);
1075}
1076
1077void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1078    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1079    ggml_sycl_op_tanh(ctx, dst);
1080}
1081
1082void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1083    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1084    ggml_sycl_op_relu(ctx, dst);
1085}
1086
1087void ggml_sycl_sigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1088    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1089    ggml_sycl_op_sigmoid(ctx, dst);
1090}
1091
1092void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1093    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1094    ggml_sycl_op_hardsigmoid(ctx, dst);
1095}
1096
1097void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1098    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1099    ggml_sycl_op_hardswish(ctx, dst);
1100}
1101
1102void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1103    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1104    ggml_sycl_op_exp(ctx, dst);
1105}
1106
1107void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1108    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1109    ggml_sycl_op_log(ctx, dst);
1110}
1111
1112void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1113    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1114    ggml_sycl_op_softplus(ctx, dst);
1115}
1116
1117void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1118    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1119    ggml_sycl_op_neg(ctx, dst);
1120}
1121
1122void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1123    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1124    ggml_sycl_op_step(ctx, dst);
1125}
1126
1127void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1128    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1129    ggml_sycl_op_leaky_relu(ctx, dst);
1130}
1131
1132void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1133    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1134    ggml_sycl_op_sqr(ctx, dst);
1135}
1136
1137void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1138    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1139    ggml_sycl_op_upscale(ctx, dst);
1140}
1141
1142
1143void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1144    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1145    ggml_sycl_op_clamp(ctx, dst);
1146}
1147
1148void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1149    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1150    ggml_sycl_op_sgn(ctx, dst);
1151}
1152
1153void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1154    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1155    ggml_sycl_op_abs(ctx, dst);
1156}
1157
1158void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1159    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1160    ggml_sycl_op_elu(ctx, dst);
1161}
1162
1163void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1164    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1165    ggml_sycl_op_geglu(ctx, dst);
1166}
1167
1168void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1169    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1170    ggml_sycl_op_reglu(ctx, dst);
1171}
1172
1173void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1174    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1175    ggml_sycl_op_swiglu(ctx, dst);
1176}
1177
1178void ggml_sycl_swiglu_oai(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1179    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1180    ggml_sycl_op_swiglu_oai(ctx, dst);
1181}
1182
1183void ggml_sycl_geglu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1184    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1185    ggml_sycl_op_geglu_erf(ctx, dst);
1186}
1187
1188void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1189    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1190    ggml_sycl_op_geglu_quick(ctx, dst);
1191}
1192
1193void ggml_sycl_arange(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1194    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0);
1195    ggml_sycl_detail::ggml_sycl_op_arange(ctx, dst);
1196}
1197
1198void ggml_sycl_floor(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1199    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1200    ggml_sycl_op_floor(ctx, dst);
1201}
1202
1203void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1204    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1205    ggml_sycl_op_ceil(ctx, dst);
1206}
1207
1208void ggml_sycl_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1209    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1210    ggml_sycl_op_round(ctx, dst);
1211}
1212
1213void ggml_sycl_trunc(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1214    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1215    ggml_sycl_op_trunc(ctx, dst);
1216}