summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl')
-rw-r--r--llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl72
1 files changed, 72 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
new file mode 100644
index 0000000..743004f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl
@@ -0,0 +1,72 @@
+#ifdef MUL_MAT_ID
+shared u16vec2 row_ids[BN];
+uint _ne1;
+
+#ifdef MUL_MAT_ID_USE_SUBGROUPS
+shared uvec4 ballots_sh[NUM_WARPS];
+
+void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
+ _ne1 = 0;
+ uint num_elements = p.nei1 * p.nei0;
+ uint nei0shift = findLSB(p.nei0);
+
+ uint ids[16];
+ uint iter = 0;
+
+ uint expert_count = data_expert_count[expert_idx];
+
+ for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
+ // prefetch up to 16 elements
+ if (iter == 0) {
+ [[unroll]] for (uint k = 0; k < 16; ++k) {
+ uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
+ }
+ }
+ uint i = j + gl_LocalInvocationIndex;
+ bool in_range = i < num_elements;
+ uint ii1;
+ if (nei0_is_pow2) {
+ ii1 = i >> nei0shift;
+ } else {
+ ii1 = i / p.nei0;
+ }
+ uint ii0 = i - ii1 * p.nei0;
+ uint id = ids[iter++];
+ uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
+
+ ballots_sh[gl_SubgroupID] = ballot;
+ barrier();
+
+ uint subgroup_base = 0;
+ uint total = 0;
+ for (uint k = 0; k < gl_NumSubgroups; ++k) {
+ if (k == gl_SubgroupID) {
+ subgroup_base = total;
+ }
+ total += subgroupBallotBitCount(ballots_sh[k]);
+ }
+ barrier();
+
+ uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
+ if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
+ row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
+ }
+ _ne1 += total;
+ iter &= 15;
+ if (_ne1 >= (ic + 1) * BN || _ne1 == expert_count) {
+ break;
+ }
+ }
+ barrier();
+}
+#endif // MUL_MAT_ID_USE_SUBGROUPS
+#endif // MUL_MAT_ID