summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl77
1 files changed, 77 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl
new file mode 100644
index 0000000..7ae21ac
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/ssm_conv.cl
@@ -0,0 +1,77 @@
+kernel void kernel_ssm_conv_f32_f32(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb00,
+ ulong nb01,
+ ulong nb02,
+ int ne10,
+ ulong nb11,
+ ulong nb0,
+ ulong nb1,
+ ulong nb2
+){
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ int ir = get_global_id(0);
+ int i2 = get_global_id(1);
+ int i3 = get_global_id(2);
+
+ int nc = ne10;
+
+ global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
+ global float * c = (global float *) (src1 + ir*nb11);
+ global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
+
+ float sumf = 0.0f;
+
+ for (int i0 = 0; i0 < nc; ++i0) {
+ sumf += s[i0] * c[i0];
+ }
+
+ d[0] = sumf;
+}
+
+kernel void kernel_ssm_conv_f32_f32_4(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ ulong nb00,
+ ulong nb01,
+ ulong nb02,
+ int ne10,
+ ulong nb11,
+ ulong nb0,
+ ulong nb1,
+ ulong nb2
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ int ir = get_global_id(0);
+ int i2 = get_global_id(1);
+ int i3 = get_global_id(2);
+
+ int nc = ne10;
+
+ global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
+ global float4 * c = (global float4 *) (src1 + ir*nb11);
+ global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
+
+ float sumf = 0.0f;
+
+ for (int i0 = 0; i0 < nc/4; ++i0) {
+ sumf += dot(s[i0], c[i0]);
+ }
+
+ d[0] = sumf;
+}