diff options
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl')
| -rw-r--r-- | llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl | 109 |
1 files changed, 109 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl new file mode 100644 index 0000000..99e9192 --- /dev/null +++ b/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -0,0 +1,109 @@ +enable f16; + +#ifdef DST_F32 +#define DST_INNER_TYPE f32 +#else +#define DST_INNER_TYPE f16 +#endif + +#ifdef VEC4 +#define SRC_TYPE vec4<f32> +#define DST_TYPE vec4<DST_INNER_TYPE> +#define VEC_SIZE 4 +#else +#define SRC_TYPE f32 +#define DST_TYPE DST_INNER_TYPE +#define VEC_SIZE 1 +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<SRC_TYPE>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<DST_TYPE>; + +#ifdef I64_IDX +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) { + return; + } + + // getting the row from gid + let elems_per_row = params.ne0 / VEC_SIZE; + var i = gid.x / elems_per_row; + + let i_src3 = i / (params.ne2 * params.n_rows); + + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; + + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1]; + + if (idx_low_val != 0) { + // Upper bits of index are not zero, output will be incorrect + atomicStore(&error, 1); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + + let col_idx = (gid.x % elems_per_row); + dst[i_dst_row/VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row/VEC_SIZE + col_idx]); +} |
