1#include "pad_reflect_1d.cuh"
 2
 3static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
 4    pad_reflect_1d_kernel_f32(
 5        const void * __restrict__ src0,
 6        void * __restrict__       dst,
 7        const int64_t             ne0,
 8        const int64_t             ne00,
 9        const uint3               ne01,
10        const int64_t             ne02,
11        const int64_t             ne03,
12        const int64_t             nb00,
13        const int64_t             nb01,
14        const int64_t             nb02,
15        const int64_t             nb03,
16        const int64_t             nb0,
17        const int64_t             nb1,
18        const int64_t             nb2,
19        const int64_t             nb3,
20        const int                 p0,
21        const int                 p1) {
22    const int64_t i3 = blockIdx.z;
23    const int64_t i2 = blockIdx.y;
24
25    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
26    const int64_t tile1          = div_mod_packed.y;  // i1
27    const int64_t tile0          = div_mod_packed.x;  // nth i0 tile
28    const int64_t i1             = tile1;
29    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;
30
31    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
32    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
33        return;
34    }
35
36    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
37    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
38
39    const int64_t rel_i0 = i0 - p0;  // relative i0 in src0
40    int64_t src_idx;
41
42    if (rel_i0 < 0) {
43        // Left padding - reflect
44        src_idx = -rel_i0;
45    } else if (rel_i0 < ne00) {
46        // Middle - copy
47        src_idx = rel_i0;
48    } else {
49        // Right padding - reflect
50        src_idx = 2 * ne00 - 2 - rel_i0;
51    }
52    const float value               = *(const float *) (src0_ptr + src_idx * nb00);
53    *(float *) (dst_ptr + i0 * nb0) = value;
54
55    GGML_UNUSED(p1);
56}
57
// Host launcher for 1-D reflect padding of F32 tensors (dst = reflect_pad(dst->src[0])).
//
// The left/right pad widths p0/p1 are read from dst->op_params[0] and dst->op_params[1].
// pad_reflect_1d_kernel_f32 is enqueued on ctx.stream() with one thread per output element along dim 0:
//   grid.x = ne01 * ceil(ne0 / CUDA_PAD_REFLECT_1D_BLOCK_SIZE)  (i1 and i0 tiles, unpacked in-kernel via fastdiv)
//   grid.y = ne02, grid.z = ne03.
//
// Preconditions, enforced with GGML_ASSERT:
//   - src0 and dst are F32,
//   - 0 <= p0 < ne00 and 0 <= p1 < ne00 (the reflection excludes the edge sample),
//   - dst->ne[0] == ne00 + p0 + p1,
//   - the grid fits CUDA's launch limits.
// Empty tensors are a no-op.
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int       p0   = opts[0];
    const int       p1   = opts[1];
    GGML_ASSERT(p0 >= 0 && p1 >= 0);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne0 = dst->ne[0];

    // sanity: padded length matches
    GGML_ASSERT(ne0 == ne00 + p0 + p1);

    // Nothing to write. Returning early also avoids building fastdiv values for a zero divisor
    // and launching with a zero-sized grid dimension (an invalid launch configuration).
    if (ne0 == 0 || ne01 == 0 || ne02 == 0 || ne03 == 0) {
        return;
    }

    // Each side may mirror at most ne00 - 1 elements. Larger pads would make the kernel read
    // source column p0 (left) or ne00 - 1 - p1 (right) outside the row.
    GGML_ASSERT(p0 < ne00 && p1 < ne00);

    const uint3 ne01_packed = init_fastdiv_values(ne01);

    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)
    const int64_t     tiles0 = (ne0 + bx - 1) / bx;             // number of tiles along i0

    // CUDA grid limits: gridDim.x <= 2^31 - 1, gridDim.y and gridDim.z <= 65535.
    // Checking here keeps the unsigned casts below from silently truncating.
    constexpr int64_t max_grid_x  = 2147483647;
    constexpr int64_t max_grid_yz = 65535;
    GGML_ASSERT(ne01 * tiles0 <= max_grid_x);
    GGML_ASSERT(ne02 <= max_grid_yz && ne03 <= max_grid_yz);

    // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
    // grid.y covers i2: [ne02]
    // grid.z covers i3: [ne03]
    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
    const dim3 block_dims((unsigned) bx, 1, 1);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
}