1#version 450
2
3#extension GL_EXT_control_flow_attributes : require
4#extension GL_KHR_shader_subgroup_basic : enable
5#extension GL_KHR_shader_subgroup_arithmetic : enable
6#extension GL_KHR_shader_subgroup_shuffle : enable
7
8#include "types.glsl"
9
// Gating function selectors; values must match what the host writes into
// the gating_func push constant below.
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2

// Per-dispatch parameters supplied by the host.
layout (push_constant) uniform parameter
{
    uint n_rows;          // number of rows (tokens) to process
    uint n_experts_push;  // expert count; only consulted when nexperts_use_push is set
    uint n_expert_used;   // top-k: experts selected per row
    float clamp_min;      // lower clamp bound for the weight sum (with_norm path)
    float clamp_max;      // upper clamp bound for the weight sum (with_norm path)
    uint gating_func;     // one of the GATING_FUNC_* values above
    uint has_bias;        // nonzero: add per-expert bias to the selection scores
    uint with_norm;       // nonzero: rescale selected weights by 1/sum
    float output_scale;   // final affine transform on each stored weight:
    float output_bias;    //   weights[i] = output_scale * w[i] + output_bias
};

// x dimension is the subgroup size (specialized via constant_id 0);
// local_size_y = 4 subgroups per workgroup, i.e. 4 rows per workgroup.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;            // subgroup size
layout(constant_id = 1) const uint n_experts_spec = 512;      // compile-time expert count
layout(constant_id = 2) const bool nexperts_use_push = false; // take expert count from push constant instead

// Effective expert count for this dispatch.
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Register slots each lane owns; sized for the worst case (the spec-constant
// count), so push-constant dispatches may leave trailing slots unused.
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

layout (binding = 0, std430) readonly buffer Logits {float logits[];};    // [n_rows][n_experts] gating logits
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};   // [n_experts]; read only when has_bias != 0
layout (binding = 2, std430) writeonly buffer Weights {float weights[];}; // [n_rows][n_expert_used] output weights
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};          // selected expert ids; row stride is n_experts

const float INFINITY = 1.0 / 0.0;
46
// Subgroup-cooperative in-place softmax over a lane-strided register array.
// Slot i on a given lane corresponds to logical index lane + i * WARP_SIZE.
// When use_limit is true, indices >= limit are excluded from the reduction
// and their slots are zeroed; otherwise every slot participates.
// Used both for the pre-top-k logits and the post-top-k delayed path.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
    // Pass 1: per-lane maximum over active slots, reduced subgroup-wide.
    float m = -INFINITY;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint slot = lane + i * WARP_SIZE;
        if (!use_limit || slot < limit) {
            m = max(m, vals[i]);
        }
    }

    m = subgroupMax(m);

    // Pass 2: exponentiate in place (shifted by the max for stability) and
    // accumulate the normalizer; inactive slots are forced to zero.
    float denom = 0.f;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint slot = lane + i * WARP_SIZE;
        if (!use_limit || slot < limit) {
            const float e = exp(vals[i] - m);
            vals[i] = e;
            denom += e;
        } else {
            vals[i] = 0.f;
        }
    }

    denom = subgroupAdd(denom);

    // Pass 3: scale active slots by the reciprocal of the subgroup-wide sum.
    const float scale = 1.0f / denom;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const uint slot = lane + i * WARP_SIZE;
        if (!use_limit || slot < limit) {
            vals[i] *= scale;
        }
    }
}
90
// One subgroup handles one row: apply the gating activation to the row's
// expert logits, pick the top n_expert_used experts by iterated argmax,
// then write the (optionally normalized, affine-transformed) weights and
// the selected expert ids.
void main() {
    // gl_WorkGroupSize.y subgroups per workgroup, one row each.
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint bias_offset = 0; // 1D
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row; // ids rows are strided by n_experts
    const uint lane = gl_SubgroupInvocationID;

    // Lane-strided register cache of this row's logits; slots past n_experts
    // stay at -inf so they can never win the argmax below.
    float probs[experts_per_thread];
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        probs[i] = -INFINITY;
    }

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }

    // Pre-top-k activation. For GATING_FUNC_SOFTMAX_WEIGHT the raw logits
    // are ranked directly and the softmax happens after selection (below).
    if (gating_func == GATING_FUNC_SOFTMAX) {
        // The explicit limit is only needed when the expert count comes from
        // the push constant; otherwise inactive slots are already -inf and
        // contribute exp(-inf) == 0 to the sum.
        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
    } else if (gating_func == GATING_FUNC_SIGMOID) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
        }
    }

    // Ranking scores: the optional bias influences which experts win the
    // argmax, but the weights written out come from the unbiased probs.
    float selection_probs[experts_per_thread];
    if (has_bias != 0) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
        }
    } else {
        [[unroll]]
        for (int i = 0; i < experts_per_thread; i++) {
            selection_probs[i] = probs[i];
        }
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    // Selected weights, lane-strided like probs: output slot k lives on
    // lane k % WARP_SIZE at register index k / WARP_SIZE.
    float output_weights[experts_per_thread];

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        // Per-lane argmax over this lane's slots (slot 0 seeds the search).
        float max_val = probs[0];
        float max_val_s = selection_probs[0];
        uint max_expert = lane;

        [[unroll]]
        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
                max_val = probs[i / WARP_SIZE];
                max_val_s = selection_probs[i / WARP_SIZE];
                max_expert = expert;
            }
        }

        // Butterfly (xor-shuffle) reduction: every lane converges on the
        // subgroup-wide argmax; ties break toward the lower expert id.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const float val_s = subgroupShuffleXor(max_val_s, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                max_val = val;
                max_val_s = val_s;
                max_expert = expert;
            }
        }

        // The lane that owns output slot k records the winning weight.
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // The lane that owns the winning expert knocks it out of subsequent
        // rounds, writes its id, and accumulates its share of the weight sum.
        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            selection_probs[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            wt_sum += max_val;
        }
    }

    if (with_norm != 0) {
        // wt_sum holds per-lane partial sums; reduce subgroup-wide, clamp,
        // then rescale so the selected weights sum to (approximately) 1.
        wt_sum = subgroupAdd(wt_sum);
        wt_sum = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    // Delayed softmax over just the n_expert_used selected weights.
    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }

    // Final affine transform and store of the top-k weights.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
        }
    }
}