summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp22
1 files changed, 22 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
new file mode 100644
index 0000000..1251f9c
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp
@@ -0,0 +1,22 @@
+#version 450
+
+#include "generic_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE data_x[];};
+layout (binding = 1) readonly buffer G {A_TYPE data_grad[];};
+layout (binding = 2) readonly buffer P {float data_params[2];};
+
+void main() {
+ const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+ if (i >= p.KX) {
+ return;
+ }
+
+ const float alpha = data_params[0];
+ const float keep = 1.f - alpha * data_params[1];
+
+ data_x[i] = data_x[i] * keep - alpha * data_grad[i];
+}