diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 0000000..e88bdd0 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + + if (i00 >= p.ne00) { + return; + } + + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#if defined(DATA_A_BF16) + TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00])); +#else + TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); +#endif + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } +} |
