1#include "pad_reflect_1d.hpp"
  2
  3static void pad_reflect_1d_kernel_f32(
  4    const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
  5    const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
  6    const int64_t ne03, const int64_t nb00, const int64_t nb01,
  7    const int64_t nb02, const int64_t nb03, const int64_t nb0,
  8    const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
  9    const int p1, sycl::nd_item<3> item_ct1) {
 10
 11    const int64_t i3 = item_ct1.get_group(0);
 12    const int64_t i2 = item_ct1.get_group(1);
 13
 14    const sycl::uint2 div_mod_packed =
 15        fast_div_modulo(item_ct1.get_group(2), ne01);
 16    const int64_t tile1 = div_mod_packed.y();
 17    const int64_t tile0 = div_mod_packed.x();
 18    const int64_t i1 = tile1;
 19    const int64_t i0 =
 20        item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
 21
 22    if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
 23        return;
 24    }
 25
 26    const char *src0_ptr =
 27        (const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
 28    char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
 29
 30    const int64_t rel_i0 = i0 - p0; // relative i0 in src0
 31    int64_t src_idx;
 32
 33    if (rel_i0 < 0) {
 34        // Left padding - reflect
 35        src_idx = -rel_i0;
 36    } else if (rel_i0 < ne00) {
 37        // Middle - copy
 38        src_idx = rel_i0;
 39    } else {
 40        // Right padding - reflect
 41        src_idx = 2 * ne00 - 2 - rel_i0;
 42    }
 43    const float value = *(const float *)(src0_ptr + src_idx * nb00);
 44    *(float *)(dst_ptr + i0 * nb0) = value;
 45
 46    GGML_UNUSED(p1);
 47}
 48
 49void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
 50                                 ggml_tensor *dst) {
 51
 52    const ggml_tensor *src0 = dst->src[0];
 53    dpct::queue_ptr stream = ctx.stream();
 54
 55    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 56    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 57
 58    const int32_t *opts = (const int32_t *)dst->op_params;
 59    const int p0 = opts[0];
 60    const int p1 = opts[1];
 61
 62    const int64_t ne00 = src0->ne[0];
 63    const int64_t ne01 = src0->ne[1];
 64    const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
 65    const int64_t ne02 = src0->ne[2];
 66    const int64_t ne03 = src0->ne[3];
 67
 68    const int64_t ne0 = dst->ne[0];
 69
 70    GGML_ASSERT(ne0 == ne00 + p0 + p1);
 71
 72    constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
 73    const int64_t tiles0 = (ne0 + bx - 1) / bx;
 74    const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
 75                               (unsigned)ne03);
 76    const dpct::dim3 block_dims((unsigned)bx, 1, 1);
 77
 78    stream->submit([&](sycl::handler &cgh) {
 79        auto src0_data_ct0 = src0->data;
 80        auto dst_data_ct1 = dst->data;
 81        auto src0_nb_ct7 = src0->nb[0];
 82        auto src0_nb_ct8 = src0->nb[1];
 83        auto src0_nb_ct9 = src0->nb[2];
 84        auto src0_nb_ct10 = src0->nb[3];
 85        auto dst_nb_ct11 = dst->nb[0];
 86        auto dst_nb_ct12 = dst->nb[1];
 87        auto dst_nb_ct13 = dst->nb[2];
 88        auto dst_nb_ct14 = dst->nb[3];
 89
 90        cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
 91                         [=](sycl::nd_item<3> item_ct1) {
 92                             pad_reflect_1d_kernel_f32(
 93                                 src0_data_ct0, dst_data_ct1, ne0, ne00,
 94                                 ne01_packed, ne02, ne03, src0_nb_ct7,
 95                                 src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
 96                                 dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
 97                                 dst_nb_ct14, p0, p1, item_ct1);
 98                         });
 99    });
100}