1#include "unary.cuh"
  2#include "convert.cuh"
  3
  4static __device__ __forceinline__ float op_abs(float x) {
  5    return fabsf(x);
  6}
  7
  8static __device__ __forceinline__ float op_sgn(float x) {
  9    return (x > 0.f ? 1.f : ((x < 0.f ? -1.f : 0.f)));
 10}
 11
 12static __device__ __forceinline__ float op_neg(float x) {
 13    return -x;
 14}
 15
 16static __device__ __forceinline__ float op_step(float x) {
 17    return x > 0.0f;
 18}
 19
 20static __device__ __forceinline__ float op_gelu(float x) {
 21    return ggml_cuda_op_gelu_single(x);
 22}
 23
 24static __device__ __forceinline__ float op_gelu_erf(float x) {
 25    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
 26
 27    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
 28}
 29
 30static __device__ __forceinline__ float op_gelu_quick(float x) {
 31    const float GELU_QUICK_COEF = -1.702f;
 32
 33    return x * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x)));
 34}
 35
 36static __device__ __forceinline__ float op_silu(float x) {
 37    return ggml_cuda_op_silu_single(x);
 38}
 39
 40static __device__ __forceinline__ float op_tanh(float x) {
 41    return tanhf(x);
 42}
 43
 44static __device__ __forceinline__ float op_relu(float x) {
 45    return fmaxf(x, 0);
 46}
 47
 48static __device__ __forceinline__ float op_sigmoid(float x) {
 49    return 1.0f / (1.0f + expf(-x));
 50}
 51
 52static __device__ __forceinline__ float op_hardsigmoid(float x) {
 53    return fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
 54}
 55
 56static __device__ __forceinline__ float op_hardswish(float x) {
 57    return x * fminf(1.0f, fmaxf(0.0f, (x + 3.0f) / 6.0f));
 58}
 59
 60static __device__ __forceinline__ float op_exp(float x) {
 61    return expf(x);
 62}
 63
 64static __device__ __forceinline__ float op_sqr(float x) {
 65    return x * x;
 66}
 67
 68static __device__ __forceinline__ float op_sqrt(float x) {
 69    return sqrtf(x);
 70}
 71
 72static __device__ __forceinline__ float op_sin(float x) {
 73    return sinf(x);
 74}
 75
 76static __device__ __forceinline__ float op_cos(float x) {
 77    return cosf(x);
 78}
 79
 80static __device__ __forceinline__ float op_log(float x) {
 81    return logf(x);
 82}
 83
 84static __device__ __forceinline__ float op_expm1(float x) {
 85    return expm1f(x);
 86}
 87
 88static __device__ __forceinline__ float op_softplus(float x) {
 89    return (x > 20.0f) ? x : logf(1.0f + expf(x));
 90}
 91
 92static __device__ __forceinline__ float op_elu(float x) {
 93    return (x > 0.f) ? x : expm1f(x);
 94}
 95
 96static __device__ __forceinline__ float op_floor(float x) {
 97    return floorf(x);
 98}
 99
// Round toward positive infinity.
static __device__ __forceinline__ float op_ceil(const float v) {
    const float result = ceilf(v);
    return result;
}
103
// Round to nearest integer, halfway cases away from zero.
// Uses the single-precision roundf (like floorf/ceilf above) so no double-precision
// overload can be selected.
static __device__ __forceinline__ float op_round(float x) {
    return roundf(x);
}
107
// Round toward zero (drop the fractional part).
// Uses the single-precision truncf so no double-precision overload can be selected.
static __device__ __forceinline__ float op_trunc(float x) {
    return truncf(x);
}
111
// Element-wise kernel: dst[i] = op(x[i]) for i in [0, k). Each element is widened to
// float, passed through op, and converted back to T.
// Launch: 1D grid, one thread per element. Threads with i >= k exit immediately.
// The index and count are 64-bit because callers pass ggml_nelements(), which can
// exceed INT_MAX for large tensors. A 32-bit index would silently wrap.
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int64_t k) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op((float)x[i]);
}
122
// Launches unary_op_kernel<op> over k contiguous elements on `stream`, using
// CUDA_NEG_BLOCK_SIZE threads per block and ceil(k / block) blocks.
// k is 64-bit so element counts above INT_MAX are not truncated at the call site.
template <float (*op)(float), typename T>
static void unary_cuda(const T * x, T * dst, const int64_t k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a zero-block grid would be an invalid launch configuration
    }
    const int64_t num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE;
    unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
128
// Generic driver for element-wise unary ops: dst = op(src0).
// Requires a contiguous src0 and matching F32/F16 types for src0 and dst. The whole
// tensor is processed as a flat array of ggml_nelements(src0) elements on ctx's stream.
template <float (*op)(float)>
void ggml_cuda_op_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const int64_t n = ggml_nelements(src0);

    switch (src0->type) {
        case GGML_TYPE_F16:
            unary_cuda<op>((const half *) src0->data, (half *) dst->data, n, stream);
            break;
        default: // F32 (guaranteed by the asserts above)
            unary_cuda<op>((const float *) src0->data, (float *) dst->data, n, stream);
            break;
    }
}
148
// Public entry points for the element-wise unary ops. Each instantiates
// ggml_cuda_op_unary with the matching device function (op_abs, op_sgn, ...).
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_abs>(ctx, dst); }
void ggml_cuda_op_sgn(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_sgn>(ctx, dst); }
void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_neg>(ctx, dst); }
void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_step>(ctx, dst); }
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_gelu>(ctx, dst); }
void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst)    { ggml_cuda_op_unary<op_gelu_erf>(ctx, dst); }
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst)  { ggml_cuda_op_unary<op_gelu_quick>(ctx, dst); }
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_silu>(ctx, dst); }
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_tanh>(ctx, dst); }
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_relu>(ctx, dst); }
void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst)     { ggml_cuda_op_unary<op_sigmoid>(ctx, dst); }
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary<op_hardsigmoid>(ctx, dst); }
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst)   { ggml_cuda_op_unary<op_hardswish>(ctx, dst); }
void ggml_cuda_op_exp(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_exp>(ctx, dst); }
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_sqr>(ctx, dst); }
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_sqrt>(ctx, dst); }
void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_sin>(ctx, dst); }
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_cos>(ctx, dst); }
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_log>(ctx, dst); }
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)         { ggml_cuda_op_unary<op_elu>(ctx, dst); }
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary<op_floor>(ctx, dst); }
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst)        { ggml_cuda_op_unary<op_ceil>(ctx, dst); }
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary<op_round>(ctx, dst); }
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary<op_trunc>(ctx, dst); }
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary<op_expm1>(ctx, dst); }
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst)    { ggml_cuda_op_unary<op_softplus>(ctx, dst); }
252/* gated ops */
253
// Gated element-wise kernel: dst[i] = op(x[row*o0 + col]) * g[row*o1 + col],
// where row = i / n and col = i % n. n is the output row length, and o0/o1 are the
// row strides (in elements) of the activation input x and the gate g.
// x and g may point into the same buffer at different offsets.
// Launch: 1D grid, one thread per output element. Threads with i >= k exit.
template <float (*op)(float), typename T>
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
    const int64_t idx = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (idx < k) {
        const int64_t row = idx / n;
        const int64_t col = idx % n;

        const int64_t jx = row*o0 + col;
        const int64_t jg = (o1 == o0) ? jx : row*o1 + col; // equal strides -> same offset

        const float activated = op((float)x[jx]);
        dst[idx] = (T)(activated * (float)g[jg]);
    }
}
268
// Launches unary_gated_op_kernel<op> over k output elements on `stream`, using
// CUDA_GLU_BLOCK_SIZE threads per block and ceil(k / block) blocks.
template <float (*op)(float), typename T>
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
    const int64_t block_size = CUDA_GLU_BLOCK_SIZE;
    const int64_t grid_size  = (k + block_size - 1) / block_size;

    unary_gated_op_kernel<op><<<grid_size, block_size, 0, stream>>>(x, g, dst, k, n, o0, o1);
}
274
// Typed helper for ggml_cuda_op_unary_gated: resolves the activation/gate pointers and
// row strides for element type T, then launches unary_gated_cuda.
// With no src1, src0 holds both halves side by side in each row: elements [0, nc) and
// [nc, 2*nc). `swapped` selects which half feeds op and which is the gate.
template <float (*op)(float), typename T>
static void ggml_cuda_op_unary_gated_typed(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                           const int64_t nc, const bool swapped, cudaStream_t stream) {
    const T * x = (const T *) src0->data;
    const T * g = (const T *) (src1 ? src1->data : src0->data);

    if (!src1) {
        x += swapped ? nc : 0;
        g += swapped ? 0 : nc;
    }

    // row strides in elements (nb[1] is in bytes)
    const int64_t o0 = src0->nb[1] / sizeof(T);
    const int64_t o1 = (src1 ? src1->nb[1] : src0->nb[1]) / sizeof(T);

    unary_gated_cuda<op>(x, g, (T *) dst->data, ggml_nelements(dst), nc, o0, o1, stream);
}

// Generic driver for gated (GLU-style) ops: dst = op(a) * b, row by row.
// a and b come either from src0 and src1 (when src1 is set) or from the two halves of
// each src0 row (when src1 is null; op_params[1] != 0 swaps the halves).
// Supports F32 and F16; all tensors involved must share the same type.
template <float (*op)(float)>
void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (src1) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
    }

    // bounds-checked accessor, same as ggml_cuda_op_swiglu_oai
    const int32_t swapped = ggml_get_op_params_i32(dst, 1);

    if (src0->type == GGML_TYPE_F16) {
        ggml_cuda_op_unary_gated_typed<op, half>(src0, src1, dst, nc, swapped != 0, stream);
    } else {
        ggml_cuda_op_unary_gated_typed<op, float>(src0, src1, dst, nc, swapped != 0, stream);
    }
}
328
// Public entry points for the gated activations. Each instantiates
// ggml_cuda_op_unary_gated with the activation applied to the non-gate half.
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary_gated<op_relu>(ctx, dst); }
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)       { ggml_cuda_op_unary_gated<op_gelu>(ctx, dst); }
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)      { ggml_cuda_op_unary_gated<op_silu>(ctx, dst); }
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst)   { ggml_cuda_op_unary_gated<op_gelu_erf>(ctx, dst); }
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst); }
348
349// swiglu_oai
350
// SwiGLU-OAI kernel: dst[i] = ggml_cuda_op_swiglu_oai_single(x[jx], g[jg], alpha, limit),
// with jx = row*o0 + col and jg = row*o1 + col (row = i / n, col = i % n).
// o0/o1 are row strides in elements. Launch: 1D grid, one thread per output element.
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
    const int64_t idx = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (idx >= k) {
        return;
    }

    const int64_t row = idx / n;
    const int64_t col = idx % n;
    const int64_t jx  = row*o0 + col;
    const int64_t jg  = (o0 == o1) ? jx : row*o1 + col;

    const float xv = x[jx];
    const float gv = g[jg];

    dst[idx] = ggml_cuda_op_swiglu_oai_single(xv, gv, alpha, limit);
}
368
// Launches swiglu_oai_kernel over k output elements on `stream`, using
// CUDA_GLU_BLOCK_SIZE threads per block and ceil(k / block) blocks.
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
    const int64_t block_size = CUDA_GLU_BLOCK_SIZE;
    const int64_t grid_size  = (k + block_size - 1) / block_size;

    swiglu_oai_kernel<<<grid_size, block_size, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
374
// SwiGLU-OAI driver (F32 only). The activation input and gate come either from
// src0/src1 or from the two halves of each src0 row (src1 == null).
// op_params: [1] = swapped halves flag, [2] = alpha, [3] = limit.
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0  = dst->src[0];
    const ggml_tensor * src1  = dst->src[1];
    const bool          split = src1 == nullptr; // both operands live in src0
    const int64_t       nc    = split ? src0->ne[0] / 2 : src0->ne[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous_1(src0));
    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
    GGML_ASSERT(ggml_is_contiguous(dst));

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(dst->ne[0] == nc);
    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));

    if (!split) {
        GGML_ASSERT(ggml_is_contiguous_1(src1));
        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
        GGML_ASSERT(src1->ne[0] == nc);
        GGML_ASSERT(src0->type == src1->type);
    }

    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
    const float   alpha   = ggml_get_op_params_f32(dst, 2);
    const float   limit   = ggml_get_op_params_f32(dst, 3);

    const float * x_p = (const float *) src0->data;
    const float * g_p = (const float *) (split ? src0->data : src1->data);

    if (split) {
        if (swapped) {
            x_p += nc;
        } else {
            g_p += nc;
        }
    }

    // row strides in elements (nb[1] is in bytes)
    const int64_t x_stride = src0->nb[1] / sizeof(float);
    const int64_t g_stride = (split ? src0->nb[1] : src1->nb[1]) / sizeof(float);

    swiglu_oai_cuda(x_p, g_p, (float *) dst->data, ggml_nelements(dst), nc, x_stride, g_stride, alpha, limit, stream);
}
418
419/* CUDA kernel + launcher for xIELU */
420
// xIELU kernel, element-wise over k values:
//   x > 0 : alpha_p * x^2 + beta * x
//   x <= 0: (expm1(min(x, eps)) - x) * alpha_n + beta * x
// Both branches are evaluated (no divergence), then one is *selected*. An arithmetic blend
// like gate*y_pos + (1-gate)*y_neg turns an overflowed (±inf) unselected branch into
// 0*inf = NaN, e.g. x = -1e20 overflows y_pos. The select returns the finite branch.
// Launch: 1D grid, one thread per element. The 64-bit index avoids wrap for k > INT_MAX.
template <typename T>
static __global__ void xielu_kernel(const T * x, T * dst, const int64_t k, float alpha_n, float alpha_p, float beta, float eps) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const float xi = ggml_cuda_cast<float>(x[i]);

    const float y_pos     = alpha_p * xi * xi + beta * xi;
    const float min_v_eps = fminf(xi, eps);
    const float y_neg     = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
    const float out       = (xi > 0.0f) ? y_pos : y_neg;

    dst[i] = ggml_cuda_cast<T>(out);
}
439
// Launches xielu_kernel over k contiguous elements on `stream`, using
// CUDA_XIELU_BLOCK_SIZE threads per block and ceil(k / block) blocks.
template <typename T>
static void xielu_cuda(const T * x, T * dst, const int64_t k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a zero-block grid would be an invalid launch configuration
    }
    // ceil-div: without the "- 1" an extra, fully idle block is launched whenever k % block == 0
    const int64_t num_blocks = (k + CUDA_XIELU_BLOCK_SIZE - 1) / CUDA_XIELU_BLOCK_SIZE;
    xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
}
445
// xIELU driver: contiguous src0 of type F32/F16, dst of the same type.
// op_params: [1] = alpha_n, [2] = alpha_p, [3] = beta, [4] = eps.
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const float alpha_n = ggml_get_op_params_f32(dst, 1);
    const float alpha_p = ggml_get_op_params_f32(dst, 2);
    const float beta    = ggml_get_op_params_f32(dst, 3);
    const float eps     = ggml_get_op_params_f32(dst, 4);

    const int64_t n = ggml_nelements(src0);

    switch (src0->type) {
        case GGML_TYPE_F16:
            xielu_cuda((const half *) src0->data, (half *) dst->data, n, alpha_n, alpha_p, beta, eps, stream);
            break;
        default: // F32 (guaranteed by the asserts above)
            xielu_cuda((const float *) src0->data, (float *) dst->data, n, alpha_n, alpha_p, beta, eps, stream);
            break;
    }
}
469
470
471
472/* silu_back */
473
// SiLU gradient: grad * s * (1 + x * (1 - s)), where s = sigmoid(x).
static __device__ __forceinline__ float op_silu_back(const float grad, const float x) {
    const float sig    = 1.0f / (1.0f + expf(-x));
    const float scaled = grad * sig;
    return scaled * (1.0f + x * (1.0f - sig));
}
478
// Element-wise SiLU backward: dst[i] = op_silu_back(grad[i], xf[i]), computed in float.
// Launch: 1D grid, one thread per element. The 64-bit index and count avoid wrap for
// tensors with more than INT_MAX elements.
template <class T>
static __global__ void silu_back_kernel(const T * grad, const T * xf, T * dst, const int64_t k) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op_silu_back((float)grad[i], (float)xf[i]);
}
489
// Launches silu_back_kernel over k contiguous elements on `stream`.
// The grid size must be derived from the same block size used for the launch. Dividing by
// a different constant (e.g. CUDA_SILU_BLOCK_SIZE) under-launches when that constant is larger,
// leaving the tail of the tensor unwritten.
template <class T>
static void silu_back_cuda(const T * grad, const T * x, T * dst, const int64_t k, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a zero-block grid would be an invalid launch configuration
    }
    const int64_t num_blocks = (k + CUDA_SILU_BACK_BLOCK_SIZE - 1) / CUDA_SILU_BACK_BLOCK_SIZE;
    silu_back_kernel<<<num_blocks, CUDA_SILU_BACK_BLOCK_SIZE, 0, stream>>>(grad, x, dst, k);
}
495
// SiLU backward driver: dst = grad * silu'(x), element-wise over contiguous F32/F16 data.
// dst->src[0] is handed to silu_back_cuda as the incoming gradient (its `grad` argument)
// and dst->src[1] as the forward-pass input (its `x` argument).
// NOTE(review): confirm this operand order against ggml_silu_back's definition.
// Both sources are read as flat arrays of the same element type, so src1 must match src0
// in type, element count and contiguity. Otherwise the kernel would silently read garbage.
void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // passed as grad
    const ggml_tensor * src1 = dst->src[1]; // passed as x

    const void * src0_d = src0->data;
    const void * src1_d = src1->data;
    void       * dst_d  = dst->data;

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(src1->type == src0->type);
    GGML_ASSERT(ggml_nelements(src1) == ggml_nelements(src0));

    if (src0->type == GGML_TYPE_F16) {
        silu_back_cuda((const half *)src0_d, (const half *)src1_d, (half *)dst_d, ggml_nelements(src0), stream);
    } else {
        silu_back_cuda((const float *)src0_d, (const float *)src1_d, (float *)dst_d, ggml_nelements(src0), stream);
    }
}
518
519/* leaky relu */
520
// Leaky ReLU: the positive part passes through, the negative part is scaled by negative_slope.
static __device__ __forceinline__ float op_leaky_relu(const float v, const float negative_slope) {
    return fmaxf(v, 0.0f) + fminf(v, 0.0f) * negative_slope;
}
524
// Element-wise leaky ReLU: dst[i] = op_leaky_relu(x[i], negative_slope), computed in float.
// Launch: 1D grid, one thread per element. The 64-bit index and count avoid wrap for
// tensors with more than INT_MAX elements.
template <class T>
static __global__ void leaky_relu_kernel(const T * x, T * dst, const int64_t k, const float negative_slope) {
    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    dst[i] = (T)op_leaky_relu((float)x[i], negative_slope);
}
535
// Launches leaky_relu_kernel over k contiguous elements on `stream`, using
// CUDA_RELU_BLOCK_SIZE threads per block and ceil(k / block) blocks.
// k is 64-bit so element counts above INT_MAX are not truncated at the call site.
template <class T>
static void leaky_relu_cuda(const T * x, T * dst, const int64_t k, const float negative_slope, cudaStream_t stream) {
    if (k <= 0) {
        return; // nothing to do; a zero-block grid would be an invalid launch configuration
    }
    const int64_t num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
    leaky_relu_kernel<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}
541
// Leaky ReLU driver: contiguous src0 of type F32/F16, dst of the same type.
// The negative slope is stored as a float in op_params[0].
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0   = dst->src[0];
    cudaStream_t        stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(src0));

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    GGML_ASSERT(src0->type == dst->type);

    const float   negative_slope = ggml_get_op_params_f32(dst, 0);
    const int64_t n              = ggml_nelements(src0);

    switch (src0->type) {
        case GGML_TYPE_F16:
            leaky_relu_cuda((const half *) src0->data, (half *) dst->data, n, negative_slope, stream);
            break;
        default: // F32 (guaranteed by the asserts above)
            leaky_relu_cuda((const float *) src0->data, (float *) dst->data, n, negative_slope, stream);
            break;
    }
}