diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/mean.cl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-opencl/kernels/mean.cl | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/mean.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/mean.cl new file mode 100644 index 0000000..5c3e8bc --- /dev/null +++ b/llama.cpp/ggml/src/ggml-opencl/kernels/mean.cl @@ -0,0 +1,39 @@ + +kernel void kernel_mean_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float *)((global char *)src0 + offset0); + dst = (global float *)((global char *)dst + offsetd); + + int i3 = get_global_id(2); + int i2 = get_global_id(1); + int i1 = get_global_id(0); + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum / ne00; +} |
