aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl58
1 files changed, 58 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
new file mode 100644
index 0000000..36eff04
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
@@ -0,0 +1,58 @@
1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
3//------------------------------------------------------------------------------
4// diag_mask_inf kernels
5//------------------------------------------------------------------------------
6kernel void kernel_diag_mask_inf(
7 global float * src0,
8 ulong offset0,
9 global float * dst,
10 ulong offsetd,
11 int ne00,
12 int ne01,
13 int n_past
14) {
15 src0 = (global float*)((global char*)src0 + offset0);
16 dst = (global float*)((global char*)dst + offsetd);
17
18 int i02 = get_global_id(2);
19 int i01 = get_global_id(1);
20 int i00 = get_global_id(0);
21
22 if (i00 > n_past + i01) {
23 dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
24 } else {
25 dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
26 }
27}
28
29kernel void kernel_diag_mask_inf_8(
30 global float4 * src0,
31 ulong offset0,
32 global float4 * dst,
33 ulong offsetd,
34 int ne00,
35 int ne01,
36 int n_past
37) {
38 src0 = (global float4*)((global char*)src0 + offset0);
39 dst = (global float4*)((global char*)dst + offsetd);
40
41 int i = 2*get_global_id(0);
42
43 dst[i+0] = src0[i+0];
44 dst[i+1] = src0[i+1];
45 int i4 = 4*i;
46 int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
47 int i01 = i4/(ne00); i4 -= i01*ne00;
48 int i00 = i4;
49 for (int k = 3; k >= 0; --k) {
50 if (i00 + 4 + k <= n_past + i01) {
51 break;
52 }
53 (&dst[i+1])[k] = -INFINITY;
54 if (i00 + k > n_past + i01) {
55 (&dst[i])[k] = -INFINITY;
56 }
57 }
58}