diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl new file mode 100644 index 0000000..36eff04 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl | |||
| @@ -0,0 +1,58 @@ | |||
| 1 | #pragma OPENCL EXTENSION cl_khr_fp16 : enable | ||
| 2 | |||
| 3 | //------------------------------------------------------------------------------ | ||
| 4 | // diag_mask_inf kernels | ||
| 5 | //------------------------------------------------------------------------------ | ||
| 6 | kernel void kernel_diag_mask_inf( | ||
| 7 | global float * src0, | ||
| 8 | ulong offset0, | ||
| 9 | global float * dst, | ||
| 10 | ulong offsetd, | ||
| 11 | int ne00, | ||
| 12 | int ne01, | ||
| 13 | int n_past | ||
| 14 | ) { | ||
| 15 | src0 = (global float*)((global char*)src0 + offset0); | ||
| 16 | dst = (global float*)((global char*)dst + offsetd); | ||
| 17 | |||
| 18 | int i02 = get_global_id(2); | ||
| 19 | int i01 = get_global_id(1); | ||
| 20 | int i00 = get_global_id(0); | ||
| 21 | |||
| 22 | if (i00 > n_past + i01) { | ||
| 23 | dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; | ||
| 24 | } else { | ||
| 25 | dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; | ||
| 26 | } | ||
| 27 | } | ||
| 28 | |||
| 29 | kernel void kernel_diag_mask_inf_8( | ||
| 30 | global float4 * src0, | ||
| 31 | ulong offset0, | ||
| 32 | global float4 * dst, | ||
| 33 | ulong offsetd, | ||
| 34 | int ne00, | ||
| 35 | int ne01, | ||
| 36 | int n_past | ||
| 37 | ) { | ||
| 38 | src0 = (global float4*)((global char*)src0 + offset0); | ||
| 39 | dst = (global float4*)((global char*)dst + offsetd); | ||
| 40 | |||
| 41 | int i = 2*get_global_id(0); | ||
| 42 | |||
| 43 | dst[i+0] = src0[i+0]; | ||
| 44 | dst[i+1] = src0[i+1]; | ||
| 45 | int i4 = 4*i; | ||
| 46 | int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; | ||
| 47 | int i01 = i4/(ne00); i4 -= i01*ne00; | ||
| 48 | int i00 = i4; | ||
| 49 | for (int k = 3; k >= 0; --k) { | ||
| 50 | if (i00 + 4 + k <= n_past + i01) { | ||
| 51 | break; | ||
| 52 | } | ||
| 53 | (&dst[i+1])[k] = -INFINITY; | ||
| 54 | if (i00 + k > n_past + i01) { | ||
| 55 | (&dst[i])[k] = -INFINITY; | ||
| 56 | } | ||
| 57 | } | ||
| 58 | } | ||
