diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp | 116 |
1 files changed, 116 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 0000000..db14f5a --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,116 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +#include "rte.glsl" +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; + uint batch_IC; +} p; + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void im2col(const uint y, const uint z) { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = y; + const uint batch = z / p.IC; + const uint ic = z % p.IC; + + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * p.KH; + + const uint base_linear_idx = gidx * NUM_ITER; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + + A_TYPE values[NUM_ITER]; + BDA_OFFSET_T offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == p.KH) { + current_ky = 0; + current_kx++; + } + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); + dst_addr.d = D_TYPE(values[idx]); +#else + data_d[offset_dst[idx]] = D_TYPE(values[idx]); +#endif + } +} + +void main() { + uint y = gl_GlobalInvocationID.y; + while (y < p.OH) { + uint z = gl_GlobalInvocationID.z; + while (z < p.batch_IC) { + im2col(y, z); + z += gl_NumWorkGroups.z; + } + y += gl_NumWorkGroups.y; + } +} |
