1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : require
  4#extension GL_KHR_shader_subgroup_basic : enable
  5#extension GL_KHR_shader_subgroup_arithmetic : enable
  6#extension GL_KHR_shader_subgroup_shuffle : enable
  7
  8#include "types.glsl"
  9
// Values of the `gating_func` push constant: how logits are turned into routing weights.
//   GATING_FUNC_SOFTMAX        - softmax over all experts of the row before top-k selection.
//   GATING_FUNC_SIGMOID        - element-wise sigmoid before top-k selection.
//   GATING_FUNC_SOFTMAX_WEIGHT - raw logits (plus bias, if any) drive the selection; softmax is
//                                applied afterwards over the n_expert_used selected weights only.
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2

// Per-dispatch parameters.
layout (push_constant) uniform parameter
{
    uint n_rows;         // number of rows; subgroups whose row index is >= n_rows exit immediately
    uint n_experts_push; // experts per row; only used when the nexperts_use_push spec constant is true
    uint n_expert_used;  // k: number of experts selected per row
    float clamp_min;     // lower bound applied to the sum of selected weights (with_norm != 0 only)
    float clamp_max;     // upper bound applied to the sum of selected weights (with_norm != 0 only)
    uint gating_func;    // one of the GATING_FUNC_* values above
    uint has_bias;       // nonzero: add bias[expert] to the selection score (not to the output weight)
    uint with_norm;      // nonzero: divide the selected weights by their clamped sum
    float output_scale;  // each written weight is output_scale * w + output_bias
    float output_bias;
};

// x: the lanes that cooperate on one row (size = spec constant 0, i.e. WARP_SIZE); y: 4 rows per workgroup.
// main() maps rows with gl_SubgroupID and lanes with gl_SubgroupInvocationID, so it assumes the
// subgroup size equals WARP_SIZE (one subgroup per row). NOTE(review): confirm the host enforces this.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

// Specialization constants:
//   WARP_SIZE         - number of lanes cooperating on one row (also local_size_x).
//   n_experts_spec    - compile-time expert count; also sizes the per-lane register arrays below.
//   nexperts_use_push - when true, the expert count comes from n_experts_push at run time instead.
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool nexperts_use_push = false;

// Effective number of experts per row (run-time value when nexperts_use_push, else the spec constant).
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;

// Integer ceiling division.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Slots each lane holds for one row: lane `l`, slot `j` holds expert l + j * WARP_SIZE.
// Sized from n_experts_spec, so a run-time n_experts_push larger than n_experts_spec would index
// past these arrays. NOTE(review): presumably the host guarantees n_experts_push <= n_experts_spec.
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

// logits:  read, row-major with row stride n_experts.
// bias:    indexed by expert only (main uses bias_offset = 0), i.e. one vector shared by all rows.
// weights: written, row stride n_expert_used.
// ids:     written, n_expert_used entries per row, but main() uses a row stride of n_experts.
//          NOTE(review): presumably ids is a view with n_experts-wide rows; confirm on the host side.
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};

// Positive infinity (1.0 / 0.0); -INFINITY marks empty or excluded slots throughout main().
const float INFINITY = 1.0 / 0.0;
 46
 47// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
 48void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
 49    float max_val = -INFINITY;
 50
 51    [[unroll]]
 52    for (int i = 0; i < experts_per_thread; i++) {
 53        const uint idx       = lane + i * WARP_SIZE;
 54        const bool is_active = !use_limit || (idx < limit);
 55        if (is_active) {
 56            max_val = max(max_val, vals[i]);
 57        }
 58    }
 59
 60    max_val = subgroupMax(max_val);
 61
 62    float sum = 0.f;
 63
 64    [[unroll]]
 65    for (int i = 0; i < experts_per_thread; i++) {
 66        const uint idx       = lane + i * WARP_SIZE;
 67        const bool is_active = !use_limit || (idx < limit);
 68        if (is_active) {
 69            const float val = exp(vals[i] - max_val);
 70            vals[i]         = val;
 71            sum += val;
 72        } else {
 73            vals[i] = 0.f;
 74        }
 75    }
 76
 77    sum = subgroupAdd(sum);
 78
 79    const float inv_sum = 1.0f / sum;
 80
 81    [[unroll]]
 82    for (int i = 0; i < experts_per_thread; i++) {
 83        const uint idx       = lane + i * WARP_SIZE;
 84        const bool is_active = !use_limit || (idx < limit);
 85        if (is_active) {
 86            vals[i] *= inv_sum;
 87        }
 88    }
 89}
 90
// Fused top-k MoE routing: each subgroup processes one row of `logits`.
//
// Per row:
//   1. Load the logits lane-strided: probs[j] on lane `lane` is expert lane + j * WARP_SIZE
//      (-INFINITY for slots at or past n_experts).
//   2. Apply the gating function: softmax or sigmoid. GATING_FUNC_SOFTMAX_WEIGHT leaves the logits as-is here.
//   3. Selection scores = probs, plus bias[expert] when has_bias != 0.
//   4. n_expert_used rounds of a subgroup-wide arg-max over the selection scores (ties go to the
//      lower expert index). Each round records the winner's id and its un-biased prob, then
//      excludes it from later rounds.
//   5. Optionally divide the selected weights by their clamped sum (with_norm), optionally softmax
//      them (SOFTMAX_WEIGHT), and write output_scale * w + output_bias.
void main() {
    // Row handled by this subgroup. Assumes subgroup size == WARP_SIZE (see the layout declaration).
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
    // `row` is the same for every lane of a subgroup, so the whole subgroup exits together and the
    // subgroup operations below still see all of its lanes.
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint bias_offset = 0; // 1D: the same bias vector is used for every row
    const uint weights_offset = n_expert_used * row;
    // NOTE(review): the ids row stride is n_experts, not n_expert_used; presumably ids is a view
    // with n_experts-wide rows. Confirm against the host code before changing.
    const uint ids_offset = n_experts * row;
    const uint lane = gl_SubgroupInvocationID;

    // Default every slot to -INFINITY so slots past n_experts stay "empty".
    float probs[experts_per_thread];
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        probs[i] = -INFINITY;
    }

    // Load this lane's logits. The `n_experts % WARP_SIZE == 0` short-circuit skips the bounds
    // check when every lane of every chunk is in range.
    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }

    if (gating_func == GATING_FUNC_SOFTMAX) {
        // Explicit limiting is only needed for the run-time expert count. With the spec-constant
        // count, out-of-range slots hold -INFINITY and become 0 after the softmax.
        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
    } else if (gating_func == GATING_FUNC_SIGMOID) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
        }
    }

    // Selection scores. The bias only influences which experts are chosen; the output weights are
    // taken from the un-biased `probs`. With has_bias, slots at or past n_experts are not written
    // here, but the selection loop below never reads them.
    float selection_probs[experts_per_thread];
    if (has_bias != 0) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
        }
    } else {
        [[unroll]]
        for (int i = 0; i < experts_per_thread; i++) {
            selection_probs[i] = probs[i];
        }
    }

    // Each lane now holds its lane-strided share of the gated probabilities and selection scores.
    // Top-k is done as n_expert_used rounds of a subgroup-wide arg-max. After each round the
    // winner's selection score is set to -inf so it cannot be chosen again.

    // Partial sum of the selected weights owned by this lane (reduced across the subgroup later).
    float wt_sum = 0.f;

    // Selected weights, lane-strided: the k-th selected weight lives on lane k % WARP_SIZE in
    // slot k / WARP_SIZE (WARP_SIZE is assumed to be a power of two, see the `& (WARP_SIZE - 1)` below).
    float output_weights[experts_per_thread];

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        // Best candidate within this lane, starting from its first slot.
        float max_val    = probs[0];
        float max_val_s  = selection_probs[0];
        uint   max_expert = lane;

        // Strict `>` keeps the lower expert index on ties, matching the tie-break in the reduction.
        [[unroll]]
        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
                max_val    = probs[i / WARP_SIZE];
                max_val_s  = selection_probs[i / WARP_SIZE];
                max_expert = expert;
            }
        }

        // Butterfly (xor-shuffle) reduction over the subgroup: highest selection score wins, ties go
        // to the lower expert index. Afterwards every lane holds the same winner.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = subgroupShuffleXor(max_val, mask);
            const float val_s  = subgroupShuffleXor(max_val_s, mask);
            const uint  expert = subgroupShuffleXor(max_expert, mask);
            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                max_val    = val;
                max_val_s  = val_s;
                max_expert = expert;
            }
        }

        // The lane that owns output position k stores the winner's (un-biased) weight.
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // The lane that owns the winning expert excludes it from later rounds, writes its id,
        // and adds its weight to that lane's partial sum.
        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            selection_probs[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            wt_sum += max_val;
        }
    }

    if (with_norm != 0) {
        // Total weight of the selected experts, clamped to [clamp_min, clamp_max] before dividing.
        wt_sum              = subgroupAdd(wt_sum);
        wt_sum              = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;

        // Unused slots are 0 and stay 0.
        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    // Delayed softmax over the n_expert_used selected weights only.
    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }

    // Write the affine-transformed weights; each lane writes the positions it owns.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
        }
    }
}