diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 0000000..1605565 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + const uint half_dim = p.dim / 2; + + if (p.dim % 2 != 0 && j == half_dim) { + data_d[d_offset + 2 * half_dim] = 0.f; + } + + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} |
