summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl58
1 files changed, 58 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
new file mode 100644
index 0000000..36eff04
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/diag_mask_inf.cl
@@ -0,0 +1,58 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+//------------------------------------------------------------------------------
+// diag_mask_inf kernels
+//------------------------------------------------------------------------------
+kernel void kernel_diag_mask_inf(
+ global float * src0,
+ ulong offset0,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int n_past
+) {
+ src0 = (global float*)((global char*)src0 + offset0);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ int i02 = get_global_id(2);
+ int i01 = get_global_id(1);
+ int i00 = get_global_id(0);
+
+ if (i00 > n_past + i01) {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+ } else {
+ dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ global float4 * src0,
+ ulong offset0,
+ global float4 * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int n_past
+) {
+ src0 = (global float4*)((global char*)src0 + offset0);
+ dst = (global float4*)((global char*)dst + offsetd);
+
+ int i = 2*get_global_id(0);
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int i4 = 4*i;
+ int i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ int i01 = i4/(ne00); i4 -= i01*ne00;
+ int i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ (&dst[i+1])[k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ (&dst[i])[k] = -INFINITY;
+ }
+ }
+}