1#include "binbcast.cuh"
  2#include <cstdint>
  3#include <utility>
  4
  5static __device__ __forceinline__ float op_repeat(const float a, const float b) {
  6    return b;
  7    GGML_UNUSED(a);
  8}
  9
 10static __device__ __forceinline__ float op_add(const float a, const float b) {
 11    return a + b;
 12}
 13
 14static __device__ __forceinline__ float op_sub(const float a, const float b) {
 15    return a - b;
 16}
 17
 18static __device__ __forceinline__ float op_mul(const float a, const float b) {
 19    return a * b;
 20}
 21
 22static __device__ __forceinline__ float op_div(const float a, const float b) {
 23    return a / b;
 24}
 25
// Generic broadcasting binary-op kernel:
//     dst[i] = bin_op(src0[i], src1[broadcast(i)])
// and, when the variadic pack src1_ptrs is non-empty, the fused operands in
// src1s are folded left-to-right with bin_op instead of the single src1.
//
// ne3 and ne10..ne13 are uint3 values pre-packed by init_fastdiv_values();
// their .z component carries the plain dimension size while the remaining
// components presumably hold the fastdiv magic constants consumed by
// fastdiv()/fastmodulo() (helpers from the included CUDA common headers —
// TODO confirm exact packing there).
//
// Launch layout: x covers dim 0 via a grid-stride loop, y covers dim 1,
// z covers dims 2 and 3 fused (split below with fastdiv by ne3).
// s* are element strides: s1..s3 for dst, s00..s03 for src0, s10..s13 for src1.
// src0 may be nullptr, in which case the accumulator starts at 0.0f
// (used by repeat, where only the broadcast operand matters).
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t *         src0,
                                   const src1_t *         src1,
                                   dst_t *                dst,
                                   const int              ne0,
                                   const int              ne1,
                                   const int              ne2,
                                   const uint3            ne3,
                                   const uint3            ne10,
                                   const uint3            ne11,
                                   const uint3            ne12,
                                   const uint3            ne13,
                                 /*const int              s0,*/
                                   const int              s1,
                                   const int              s2,
                                   const int              s3,
                                   const int              s00,
                                   const int              s01,
                                   const int              s02,
                                   const int              s03,
                                   const int              s10,
                                   const int              s11,
                                   const int              s12,
                                   const int              s13,
                                   src1_ptrs... src1s) {
    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1  = (blockDim.y * blockIdx.y + threadIdx.y);
    // Unfuse the z index into dims 2 and 3: i2 = z / ne3, i3 = z % ne3.
    const uint32_t i2  = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3  = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);

    // Tail guard: the grid may overshoot the tensor extents.
    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
        return;
    }

    // Broadcast: wrap the dst indices into src1's (possibly smaller) extents.
    const uint32_t i11 = fastmodulo(i1, ne11);
    const uint32_t i12 = fastmodulo(i2, ne12);
    const uint32_t i13 = fastmodulo(i3, ne13);

    // Row base offsets (dims 1..3); dim 0 is handled inside the loop.
    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    // Grid-stride loop over dim 0.
    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
            // Fused path: fold every extra src1 operand into the accumulator.
            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
        } else {
            result = bin_op(result, (float)src1[i_src1 + i10*s10]);
        }

        dst_row[i0] = (dst_t) result;
    }
}
 88
// Flat-index variant of k_bin_bcast, used when the natural 3-D grid would
// exceed CUDA's gridDim.y/z limits: the launcher flattens all four dims into
// one 1-D grid and this kernel unravels the linear index back into
// (i0, i1, i2, i3) with fastdiv by the precomputed products
// prod_012 = ne0*ne1*ne2 and prod_01 = ne0*ne1.
//
// All uint3 parameters are init_fastdiv_values() packings; their .z component
// carries the plain value (see the fastdiv helpers — TODO confirm packing).
// Strides and the src0 == nullptr / variadic fused-operand semantics are
// identical to k_bin_bcast; each thread here handles exactly one element.
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t *         src0,
                                           const src1_t *         src1,
                                           dst_t *                dst,
                                           const uint3            ne0,
                                           const uint3            ne1,
                                           const uint3            ne2,
                                           const uint32_t         ne3,
                                           const uint3            prod_012,
                                           const uint3            prod_01,
                                           const uint3            ne10,
                                           const uint3            ne11,
                                           const uint3            ne12,
                                           const uint3            ne13,
                                         /*const int              s0,*/
                                           const int              s1,
                                           const int              s2,
                                           const int              s3,
                                           const int              s00,
                                           const int              s01,
                                           const int              s02,
                                           const int              s03,
                                           const int              s10,
                                           const int              s11,
                                           const int              s12,
                                           const int              s13,
                                           src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    // Unravel the linear index: i3 = i / (ne0*ne1*ne2), then peel off dims
    // 2, 1, 0 from the remainder.
    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    // Tail guard: the last block may run past the element count.
    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    // Broadcast: wrap the dst indices into src1's (possibly smaller) extents.
    const int i11 = fastmodulo(i1, ne11);
    const int i12 = fastmodulo(i2, ne12);
    const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 =  i3*s03 +  i2*s02 +  i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
        // Fused path: fold every extra src1 operand into the accumulator.
        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
    } else {
        result = bin_op(result, (float)src1[i_src1 + i10*s10]);
    }

    dst_row[i0] = (dst_t) result;
}
153
// Host-side dispatcher for the broadcast kernels.
//
// Collapses leading dimensions that need no broadcasting (to maximize the
// contiguous inner extent), converts byte strides to element strides, picks a
// launch configuration, and launches either k_bin_bcast (3-D grid) or
// k_bin_bcast_unravel (flat 1-D grid, when the 3-D grid would exceed the
// 65535 limit on gridDim.y/z).
//
// The index_sequence I... carries the fuse count: for each I, the data
// pointer of dst->src[I + 1] is passed as an extra fused src1 operand.
// An empty sequence selects the plain single-src1 kernel path.
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                                  cudaStream_t stream, std::index_sequence<I...>) {
    GGML_TENSOR_BINARY_OP_LOCALS

    // Per-dimension repeat factors (1 means no broadcasting on that dim).
    int nr0 = ne10 / ne0;
    int nr1 = ne11 / ne1;
    int nr2 = ne12 / ne2;
    int nr3 = ne13 / ne3;

    int nr[4] = { nr0, nr1, nr2, nr3 };

    // Mutable copies of extents/strides for dst (cne/cnb), src0 (cne0/cnb0)
    // and src1 (cne1/cnb1) that the collapse step rewrites in place.
    int64_t cne[]  = { ne0, ne1, ne2, ne3 };
    int64_t cne0[] = { ne00, ne01, ne02, ne03 };
    int64_t cne1[] = { ne10, ne11, ne12, ne13 };

    size_t cnb[]  = { nb0, nb1, nb2, nb3 };
    size_t cnb0[] = { nb00, nb01, nb02, nb03 };
    size_t cnb1[] = { nb10, nb11, nb12, nb13 };

    // Fold dimension 1 into dimension 0 and shift the rest down.
    auto collapse = [](int64_t cne[]) {
        cne[0] *= cne[1];
        cne[1] = cne[2];
        cne[2] = cne[3];
        cne[3] = 1;
    };

    // Keep byte strides consistent with the collapsed extents.
    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
        cnb[1] *= cne[1];
        cnb[2] *= cne[2];
        cnb[3] *= cne[3];
    };

    // Collapse leading non-broadcast dims — only valid when both operands are
    // contiguous and unpermuted; stop at the first dim that is repeated.
    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) {
                break;
            }
            if (i > 0) {
                collapse_nb(cnb, cne);
                collapse_nb(cnb0, cne0);
                collapse_nb(cnb1, cne1);
                collapse(cne);
                collapse(cne0);
                collapse(cne1);
            }
        }
    }

    {
        // Shadow the GGML_TENSOR_BINARY_OP_LOCALS names with the collapsed
        // extents/strides for the rest of the function.
        int64_t ne0 = cne[0];
        int64_t ne1 = cne[1];
        int64_t ne2 = cne[2];
        int64_t ne3 = cne[3];

        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

        size_t nb0 = cnb[0];
        size_t nb1 = cnb[1];
        size_t nb2 = cnb[2];
        size_t nb3 = cnb[3];

        size_t nb00 = cnb0[0];
        size_t nb01 = cnb0[1];
        size_t nb02 = cnb0[2];
        size_t nb03 = cnb0[3];

        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
        size_t nb12 = cnb1[2];
        size_t nb13 = cnb1[3];

        // Byte strides -> element strides, one set per tensor.
      //size_t s0 = nb0 / sizeof(dst_t);
        size_t s1 = nb1 / sizeof(dst_t);
        size_t s2 = nb2 / sizeof(dst_t);
        size_t s3 = nb3 / sizeof(dst_t);

        size_t s10 = nb10 / sizeof(src1_t);
        size_t s11 = nb11 / sizeof(src1_t);
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

        size_t s00 = nb00 / sizeof(src0_t);
        size_t s01 = nb01 / sizeof(src0_t);
        size_t s02 = nb02 / sizeof(src0_t);
        size_t s03 = nb03 / sizeof(src0_t);

        // Element strides are only meaningful if byte strides divide evenly.
        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);

        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);

        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);

        const int block_size = 128;

        // Half of dim 0: the x-dimension is launched at half extent and the
        // kernel's grid-stride loop covers the rest.
        int64_t hne0 = std::max(ne0 / 2LL, 1LL);

        // Fill the block from x outward, capping z at 64 threads.
        dim3 block_dims;
        block_dims.x = std::min<unsigned int>(hne0, block_size);
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

        // Pre-pack the src1 extents for in-kernel fastdiv/fastmodulo.
        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

        // gridDim.y and gridDim.z are capped at 65535; beyond that, fall back
        // to a flat 1-D launch of the unravel kernel.
        if (block_nums.z > 65535 || block_nums.y > 65535) {
            int         block_num  = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            const uint3 prod_012    = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
            const uint3 prod_01     = init_fastdiv_values((uint32_t) (ne0 * ne1));
            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

            if constexpr (sizeof...(I) > 0) {
                // Fused launch: pass dst->src[1..n] as extra operand pointers.
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                    ne12, ne13,
                  /*s0,*/ s1,  s2,  s3,
                    s00, s01, s02, s03,
                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                         /*s0,*/ s1,  s2,  s3,
                                                           s00, s01, s02, s03,
                                                           s10, s11, s12, s13);
            }
        } else {
            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
                // Fused launch: pass dst->src[1..n] as extra operand pointers.
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                  /*s0,*/ s1, s2,  s3,
                    s00 ,s01, s02, s03,
                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                  /*s0,*/ s1,  s2,  s3,
                    s00, s01, s02, s03,
                    s10, s11, s12, s13);
            }
        }
    }
}
318
// Adjoint of repeat: each dst element accumulates every src element that maps
// onto it modulo the dst extents:
//     dst[j3,j2,j1,j0] = sum src[i3,i2,i1,i0]  with  i_k ≡ j_k (mod ne_k)
// s00..s03 are element strides of src; dst is written contiguously.
// Assumes the caller launches exactly ne1 blocks in y and ne2*ne3 in z
// (see repeat_back_cuda), so only dim 0 needs a tail guard.
template <typename T>
static __global__ void k_repeat_back(
    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {

    const int64_t j0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
    const int64_t j1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
    // z fuses dims 2 and 3: split it back apart.
    const int64_t j23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
    const int64_t j2  = j23 % ne2;
    const int64_t j3  = j23 / ne2;

    if (j0 >= ne0) {
        return;
    }

    // Stride over every src position that folds onto (j3, j2, j1, j0).
    T acc = 0;
    for (int64_t i3 = j3; i3 < ne03; i3 += ne3) {
        for (int64_t i2 = j2; i2 < ne02; i2 += ne2) {
            for (int64_t i1 = j1; i1 < ne01; i1 += ne1) {
                for (int64_t i0 = j0; i0 < ne00; i0 += ne0) {
                    acc += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
                }
            }
        }
    }
    dst[((j3*ne2 + j2)*ne1 + j1)*ne0 + j0] = acc;
}
347
// Functor that binds a scalar op and a fuse count so the generic dispatcher
// can receive them as a single template argument.  n_fuse is expanded into an
// index_sequence: 0 selects the plain single-src1 kernel path, n > 0 passes
// dst->src[1..n] as fused operands inside launch_bin_bcast_pack.
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
    // Launch the broadcast kernel for the given concrete element types.
    template<typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
            cudaStream_t stream) {
        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
    }
};
358
// Host launcher for k_repeat_back: one warp-wide block per chunk of dim 0,
// dim 1 mapped 1:1 onto grid.y, dims 2 and 3 fused into grid.z.
// The caller must keep ne2*ne3 within the gridDim.z launch limit.
template <typename T>
static void repeat_back_cuda(
    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {

    const int64_t blocks_x = (ne0 + WARP_SIZE - 1) / WARP_SIZE; // ceil-div over dim 0
    const dim3 threads_per_block(WARP_SIZE, 1, 1);
    const dim3 num_blocks(blocks_x, ne1, ne2*ne3);
    k_repeat_back<T><<<num_blocks, threads_per_block, 0, stream>>>(
        src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}
370
// Type-dispatch wrapper: selects the concrete (src0, src1, dst) element types
// and invokes the op functor with correspondingly typed pointers.
// Aborts on unsupported combinations.
template<class op>
static void ggml_cuda_op_bin_bcast(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {

    // src1 is only ever read as f32 or f16 below.
    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    const ggml_type t0 = src0->type;
    const ggml_type t1 = src1->type;
    const ggml_type td = dst->type;

    // Order matters: later cases are only reached when earlier ones fail.
    if (t0 == GGML_TYPE_F32 && td == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else if (t0 == GGML_TYPE_F16 && t1 == GGML_TYPE_F16 && td == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const half *) src1_dd, (half *) dst_dd, stream);
    } else if (t0 == GGML_TYPE_F16 && t1 == GGML_TYPE_F32 && td == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (half *) dst_dd, stream);
    } else if (t0 == GGML_TYPE_F16 && td == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}
392
// Backend entry point for repeat: broadcasts dst->src[0] into dst's shape.
// Note the argument shuffle: src0_dd is nullptr and src0's data is passed as
// the src1 operand — op_repeat ignores the accumulator and returns the
// broadcast value, and n_fuse == 0 selects the single-src1 kernel path.
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
396
// Backend entry point: dst = src0 + src1, broadcasting src1 over src0.
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
400
// Backend entry point: dst = src0 - src1, broadcasting src1 over src0.
void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
404
// Backend entry point: dst = src0 * src1, broadcasting src1 over src0.
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
408
// Backend entry point: dst = src0 / src1, broadcasting src1 over src0.
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
412
// Fused variant of the binbcast type dispatch: n_fuse extra src1 operands
// (dst->src[1..n_fuse]) are folded into dst inside a single kernel launch.
// Same supported type combinations as ggml_cuda_op_bin_bcast.
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    // Fix: validate src1's type as ggml_cuda_op_bin_bcast does (L376-style
    // check); without it an unsupported src1 type would be silently
    // reinterpret-cast as float/half in the branches below.
    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    // Order matters: later cases are only reached when earlier ones fail.
    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
        fprintf(stderr,
                "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}
443
444
// Backend entry point for a fused chain of n_fuse additions.
// n_fuse is a template parameter of the kernel pack, so it must be a
// compile-time constant: dispatch over the supported range with a switch.
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2:  ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst); break;
        case 3:  ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst); break;
        case 4:  ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst); break;
        case 5:  ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst); break;
        case 6:  ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst); break;
        case 7:  ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst); break;
        case 8:  ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst); break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
    }
}
474
// Backend entry point for repeat_back: reduces the repeated tensor src0 back
// into the smaller dst by summation (the adjoint of repeat).
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(ggml_is_contiguous(dst));
    // dst repeated must be able to cover src0, i.e. src0 is the larger tensor.
    GGML_ASSERT(ggml_can_repeat(dst, src0));

    cudaStream_t stream = ctx.stream();

    // Brings ne00..ne03/nb00..nb03 (src0) and ne0..ne3/nb0..nb3 (dst) into scope.
    GGML_TENSOR_UNARY_OP_LOCALS;

    // repeat_back_cuda places ne2*ne3 in gridDim.z; keep it well inside the
    // launch limit.
    GGML_ASSERT(ne2*ne3 <= (1 << 15));

    // Byte strides -> element strides for the kernel.
    const size_t ts = ggml_type_size(src0->type);
    const size_t s00 = nb00 / ts;
    const size_t s01 = nb01 / ts;
    const size_t s02 = nb02 / ts;
    const size_t s03 = nb03 / ts;

    switch (dst->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            float       * dst_d  = (float       *) dst->data;
            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
        } break;
        default: {
            // Only f32 is supported here.
            GGML_ASSERT(false);
        } break;
    }
}