diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 0000000..1f05f92 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} |
