summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp116
1 files changed, 116 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
new file mode 100644
index 0000000..db14f5a
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -0,0 +1,116 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+#extension GL_EXT_control_flow_attributes : require
+
+#include "rte.glsl"
+#include "types.glsl"
+
+layout (push_constant) uniform parameter
+{
+ BDA_STORAGE_T dst_addr;
+ uint batch_offset; uint offset_delta;
+ uint IC;
+ uint IW; uint IH;
+ uint OW; uint OH;
+ uint KW; uint KH;
+ uint pelements;
+ uint CHW;
+ int s0; int s1;
+ int p0; int p1;
+ int d0; int d1;
+ uint batch_IC;
+} p;
+
+layout(constant_id = 0) const uint BLOCK_SIZE = 32;
+
+const uint NUM_ITER = 512 / BLOCK_SIZE;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
+
+#if BDA
+layout (buffer_reference) buffer D_ptr {D_TYPE d;};
+#endif
+
+void im2col(const uint y, const uint z) {
+ const uint gidx = gl_GlobalInvocationID.x;
+
+ const uint oh = y;
+ const uint batch = z / p.IC;
+ const uint ic = z % p.IC;
+
+ const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
+ const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
+ const int oh_s1 = int(oh) * p.s1;
+ const uint ksize = p.OW * p.KH;
+
+ const uint base_linear_idx = gidx * NUM_ITER;
+
+ uint current_kx = base_linear_idx / ksize;
+ const uint rem = base_linear_idx - (current_kx * ksize);
+ uint current_ky = rem / p.OW;
+ uint current_ix = rem % p.OW;
+
+ A_TYPE values[NUM_ITER];
+ BDA_OFFSET_T offset_dst[NUM_ITER];
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+ values[idx] = A_TYPE(0);
+ }
+
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+ const uint linear_idx = base_linear_idx + idx;
+
+ if (linear_idx >= p.pelements) {
+ continue;
+ }
+
+ const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
+ const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
+
+ offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
+
+ if ((iih < p.IH) && (iiw < p.IW)) {
+ values[idx] = data_a[src_base + iih * p.IW + iiw];
+ }
+
+ if (++current_ix == p.OW) {
+ current_ix = 0;
+ if (++current_ky == p.KH) {
+ current_ky = 0;
+ current_kx++;
+ }
+ }
+ }
+
+ [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
+
+ const uint linear_idx = base_linear_idx + idx;
+
+ if (linear_idx >= p.pelements) {
+ continue;
+ }
+
+#if BDA
+ D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
+ dst_addr.d = D_TYPE(values[idx]);
+#else
+ data_d[offset_dst[idx]] = D_TYPE(values[idx]);
+#endif
+ }
+}
+
+void main() {
+ uint y = gl_GlobalInvocationID.y;
+ while (y < p.OH) {
+ uint z = gl_GlobalInvocationID.z;
+ while (z < p.batch_IC) {
+ im2col(y, z);
+ z += gl_NumWorkGroups.z;
+ }
+ y += gl_NumWorkGroups.y;
+ }
+}