summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl')
-rw-r--r--llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl187
1 files changed, 187 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl b/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl
new file mode 100644
index 0000000..c2962ed
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-opencl/kernels/get_rows.cl
@@ -0,0 +1,187 @@
+#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
+typedef char int8_t;
+typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
+typedef int int32_t;
+typedef uint uint32_t;
+
+#define QK4_0 32
+
+//------------------------------------------------------------------------------
+// block_q4_0
+//------------------------------------------------------------------------------
+struct block_q4_0
+{
+ half d;
+ uint8_t qs[QK4_0 / 2];
+};
+
+
+//------------------------------------------------------------------------------
+// dequantize_q4_0_f32, dequantize_q4_0_f16
+//------------------------------------------------------------------------------
+void dequantize_q4_0_f32(global struct block_q4_0 * xb, short il, float16 * reg) {
+ global ushort * qs = ((global ushort *)xb + 1);
+ float d1 = il ? (xb->d / 16.h) : xb->d;
+ float d2 = d1 / 256.f;
+ float md = -8.h * xb->d;
+ ushort mask0 = il ? 0x00F0 : 0x000F;
+ ushort mask1 = mask0 << 8;
+
+ reg->s0 = d1 * (qs[0] & mask0) + md;
+ reg->s1 = d2 * (qs[0] & mask1) + md;
+
+ reg->s2 = d1 * (qs[1] & mask0) + md;
+ reg->s3 = d2 * (qs[1] & mask1) + md;
+
+ reg->s4 = d1 * (qs[2] & mask0) + md;
+ reg->s5 = d2 * (qs[2] & mask1) + md;
+
+ reg->s6 = d1 * (qs[3] & mask0) + md;
+ reg->s7 = d2 * (qs[3] & mask1) + md;
+
+ reg->s8 = d1 * (qs[4] & mask0) + md;
+ reg->s9 = d2 * (qs[4] & mask1) + md;
+
+ reg->sa = d1 * (qs[5] & mask0) + md;
+ reg->sb = d2 * (qs[5] & mask1) + md;
+
+ reg->sc = d1 * (qs[6] & mask0) + md;
+ reg->sd = d2 * (qs[6] & mask1) + md;
+
+ reg->se = d1 * (qs[7] & mask0) + md;
+ reg->sf = d2 * (qs[7] & mask1) + md;
+}
+
+
+//------------------------------------------------------------------------------
+// get_rows
+//------------------------------------------------------------------------------
+kernel void kernel_get_rows_f32(
+ global void * src0,
+ ulong offset0,
+ global int * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne10,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3
+) {
+ src0 = (global void*)((global char*)src0 + offset0);
+ src1 = (global int*)((global char*)src1 + offset1);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ int i10 = get_group_id(0);
+ int i11 = get_group_id(1);
+ int i12 = get_group_id(2);
+
+ int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
+
+ int i02 = i11;
+ int i03 = i12;
+
+ for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+ if (ind >= ne00) {
+ return;
+ }
+ ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+ ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
+ }
+}
+
+kernel void kernel_get_rows_f16(
+ global void * src0,
+ ulong offset0,
+ global int * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne10,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3
+) {
+ src0 = (global void*)((global char*)src0 + offset0);
+ src1 = (global int*)((global char*)src1 + offset1);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ int i10 = get_group_id(0);
+ int i11 = get_group_id(1);
+ int i12 = get_group_id(2);
+
+ int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
+
+ int i02 = i11;
+ int i03 = i12;
+
+ for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
+ if (ind >= ne00) {
+ return;
+ }
+ ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
+ ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
+ }
+}
+
+kernel void kernel_get_rows_q4_0(
+ global void * src0,
+ ulong offset0,
+ global int * src1,
+ ulong offset1,
+ global float * dst,
+ ulong offsetd,
+ int ne00,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne10,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3
+) {
+ src0 = (global void*)((global char*)src0 + offset0);
+ src1 = (global int*)((global char*)src1 + offset1);
+ dst = (global float*)((global char*)dst + offsetd);
+
+ const int NL = 2;
+
+ int i10 = get_group_id(0);
+ int i11 = get_group_id(1);
+ int i12 = get_group_id(2);
+
+ int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
+
+ int i02 = i11;
+ int i03 = i12;
+
+ for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
+ float16 temp;
+ if (ind >= ne00) {
+ return;
+ }
+ dequantize_q4_0_f32(
+ ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
+ *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}