summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl295
1 files changed, 295 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
new file mode 100644
index 0000000..84dc8db
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl
@@ -0,0 +1,295 @@
+#define(VARIANTS)
+
+[
+ {
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_inplace",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_inplace",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_ff",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f32_ff_inplace",
+ "REPLS": {
+ "TYPE" : "f32",
+ },
+ "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_ff",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"]
+ },
+ {
+ "SHADER_SUFFIX": "f16_ff_inplace",
+ "REPLS": {
+ "TYPE" : "f16",
+ },
+ "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"]
+ }
+]
+
+#end(VARIANTS)
+
+#define(DECLS)
+
+#decl(ROTATE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ dst[i_dst0] = {{TYPE}}(out0);
+ dst[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE)
+
+#decl(ROTATE_INPLACE)
+fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) {
+ src0[i_dst0] = {{TYPE}}(out0);
+ src0[i_dst1] = {{TYPE}}(out1);
+}
+#enddecl(ROTATE_INPLACE)
+
+#decl(NO_FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+ return 1.0f;
+}
+#enddecl(NO_FF_FUNC)
+
+#decl(FF_FUNC)
+fn freq_factor(i: u32) -> f32 {
+ return src2[params.offset_src2 + i/2];
+}
+#enddecl(FF_FUNC)
+
+#decl(NO_FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS)
+
+#decl(NO_FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#enddecl(NO_FF_BINDINGS_INPLACE)
+
+#decl(FF_BINDINGS)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> dst: array<{{TYPE}}>;
+
+@group(0) @binding(4)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS)
+
+#decl(FF_BINDINGS_INPLACE)
+
+@group(0) @binding(2)
+var<storage, read_write> src2: array<f32>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+#enddecl(FF_BINDINGS_INPLACE)
+
+#end(DECLS)
+
+#define(SHADER)
+
+enable f16;
+
+struct Params {
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_src2: u32,
+ offset_dst: u32,
+
+ // Strides (in elements)
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ n_threads: u32,
+ ne0: u32,
+ ne1: u32,
+ ne2: u32,
+
+ n_dims: u32,
+ mode: u32,
+ theta_scale: f32,
+ attn_factor: f32,
+ freq_scale: f32,
+ ext_factor: f32,
+ corr_dim0: f32,
+ corr_dim1: f32,
+ sections0: u32,
+ sections1: u32,
+ sections2: u32,
+ sections3: u32
+};
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<{{TYPE}}>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<i32>;
+
+DECLS
+
+fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 {
+ let y = (f32(i / 2) - low) / max(0.001f, high - low);
+ return 1.0f - min(1.0f, max(0.0f, y));
+}
+
+// returns vector of (cos_theta, sin_theta)
+// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row
+fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> {
+ var mscale = params.attn_factor;
+ var theta = params.freq_scale * theta_extrap;
+ if (params.ext_factor != 0.0f) {
+ let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor;
+ theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix;
+ mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale);
+ }
+ return vec2<f32>(cos(theta) * mscale, sin(theta) * mscale);
+}
+
+fn pair_base(i0: u32, div_2: bool) -> u32 {
+ if (div_2) {
+ return i0 / 2;
+ } else {
+ return i0;
+ }
+}
+
+fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 {
+ if (is_vision) {
+ return params.n_dims;
+ } else if (is_neox || is_mrope) {
+ return params.n_dims / 2;
+ } else {
+ return 1;
+ }
+}
+
+override wg_size: u32;
+@compute @workgroup_size(wg_size)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ // two elements per thread
+ if (gid.x >= params.n_threads) {
+ return;
+ }
+
+ let is_neox = bool(params.mode & 2);
+ let is_mrope = bool(params.mode & 8);
+ let is_imrope = params.mode == 40;
+ let is_vision = params.mode == 24;
+
+ var i = gid.x * 2; // start index for this thread
+ let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+ i = i % (params.ne2 * params.ne1 * params.ne0);
+ let i2 = i / (params.ne1 * params.ne0);
+ i = i % (params.ne1 * params.ne0);
+ let i1 = i / params.ne0;
+ let i0 = i % params.ne0;
+
+ let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01;
+ let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+
+ if (i0 >= params.n_dims && !is_vision) {
+ let i_src = i_src_row + i0;
+ let i_dst = i_dst_row + i0;
+ rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1]));
+ return;
+ }
+
+ var theta_base_mult: u32 = 0;
+ var theta_scale_pwr: u32 = i0 / 2;
+ if (is_mrope) {
+ let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3;
+ let sec_w = params.sections1 + params.sections0;
+ let sec_e = params.sections2 + sec_w;
+ let sector = (i0 / 2) % sect_dims;
+ if (is_imrope) {
+ if (sector % 3 == 1 && sector < 3 * params.sections1) {
+ theta_base_mult = 1;
+ } else if (sector % 3 == 2 && sector < 3 * params.sections2) {
+ theta_base_mult = 2;
+ } else if (sector % 3 == 0 && sector < 3 * params.sections0) {
+ theta_base_mult = 0;
+ } else {
+ theta_base_mult = 3;
+ }
+ } else {
+ if (sector >= params.sections0 && sector < sec_w) {
+ theta_base_mult = 1;
+ if (is_vision) {
+ theta_scale_pwr = sector - params.sections0;
+ }
+ } else if (sector >= sec_w && sector < sec_e) {
+ theta_base_mult = 2;
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_w;
+ }
+ } else if (sector >= sec_e) {
+ if (is_vision) {
+ theta_scale_pwr = sector - sec_e;
+ theta_scale_pwr = (i0 / 2) % sec_e;
+ }
+ theta_base_mult = 3;
+ } else if (is_vision) {
+ theta_scale_pwr = sector;
+ }
+ }
+ }
+ let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr));
+ let thetas = rope_yarn(theta_base/freq_factor(i0), i0);
+
+ let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision);
+ let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision);
+
+ let x0 = f32(src0[i_src]);
+ let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]);
+ rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x);
+}
+
+#end(SHADER)