summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl51
1 files changed, 51 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl
new file mode 100644
index 0000000..0c1b3d7
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/concat.cl
@@ -0,0 +1,51 @@
+kernel void kernel_concat_f32(
+ global const char * src0,
+ ulong offset0,
+ global const char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ int ne00,
+ int ne01,
+ int ne02,
+ int ne03,
+ ulong nb00,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ ulong nb13,
+ int ne0,
+ ulong nb0,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3,
+ int dim
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ const int i3 = get_group_id(2);
+ const int i2 = get_group_id(1);
+ const int i1 = get_group_id(0);
+
+ int o[4] = {0, 0, 0, 0};
+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+
+ global const float * x;
+
+ for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) {
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ } else {
+ x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ }
+
+ global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ *y = *x;
+ }
+}