aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
blob: ef2f202ec9b6a005b7c862bbc4beba7a6f85f222 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

// Selector values for the `gating_func` push constant (see main()):
//   GATING_FUNC_SOFTMAX        - softmax over the row's logits before top-k selection.
//   GATING_FUNC_SIGMOID        - element-wise sigmoid of the logits before top-k selection.
//   GATING_FUNC_SOFTMAX_WEIGHT - top-k is chosen on the raw logits; softmax is applied
//                                afterwards over only the n_expert_used selected weights.
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2

layout (push_constant) uniform parameter
{
    uint n_rows;          // number of rows to process; rows >= n_rows exit early
    uint n_experts_push;  // expert count per row, used only when nexperts_use_push is true
    uint n_expert_used;   // k: number of experts selected per row
    float clamp_min;      // lower bound applied to the selected-weight sum when with_norm != 0
    float clamp_max;      // upper bound applied to the selected-weight sum when with_norm != 0
    uint gating_func;     // one of the GATING_FUNC_* values above
    uint has_bias;        // nonzero: add bias[expert] to the scores used for ranking only
    uint with_norm;       // nonzero: divide selected weights by their (clamped) sum
    float output_scale;   // written weight = output_scale * weight + output_bias
    float output_bias;
};

// x = WARP_SIZE (specialization constant 0): one subgroup per row.
// y = 4: four rows per workgroup.
layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
// Compile-time expert count. It also sizes the per-lane register arrays (experts_per_thread),
// so it acts as the upper bound when the count comes from the push constant instead.
layout(constant_id = 1) const uint n_experts_spec = 512;
// When true, the runtime n_experts_push is used as the expert count instead of n_experts_spec.
layout(constant_id = 2) const bool nexperts_use_push = false;

// Effective expert count per row. When nexperts_use_push is set, n_experts_push must not
// exceed n_experts_spec, otherwise main() would index the register arrays out of range.
// That is not checked here and is presumably guaranteed by the host.
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

// Register slots per lane: lane L holds experts L, L + WARP_SIZE, L + 2 * WARP_SIZE, ...
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

// Input logits; main() reads row r starting at n_experts * r.
layout (binding = 0, std430) readonly buffer Logits {float logits[];};
// Per-expert selection bias, indexed by expert only (shared by all rows); read when has_bias != 0.
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
// Output weights; main() writes n_expert_used values per row starting at n_expert_used * r.
layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
// Output expert ids; main() writes n_expert_used ids per row starting at n_experts * r
// (row stride n_experts - presumably matches the host's ids tensor layout; TODO confirm).
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};

// +infinity; -INFINITY marks padding slots and experts already selected.
const float INFINITY = 1.0 / 0.0;

// Returns whether register slot `slot` of invocation `lane` takes part in
// softmax_warp_inplace: always when `use_limit` is false, otherwise only when the
// element index lane + slot * WARP_SIZE is below `limit`.
bool softmax_slot_in_range(const uint slot, const uint limit, const uint lane, const bool use_limit) {
    if (!use_limit) {
        return true;
    }
    return lane + slot * WARP_SIZE < limit;
}

// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
//
// Lane `lane` holds row elements lane, lane + WARP_SIZE, lane + 2 * WARP_SIZE, ... in
// `vals`. The max and sum are reduced across the whole subgroup, so this should be called
// from subgroup-uniform control flow. Slots that do not participate (see
// softmax_slot_in_range) are skipped in both reductions and overwritten with 0.
// Participating slots holding -INFINITY become 0 provided the row maximum is finite.
// The row maximum is subtracted before exp() so large inputs do not overflow.
void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
    // Pass 1: subgroup-wide maximum over the participating slots.
    float row_max = -INFINITY;
    [[unroll]]
    for (uint slot = 0; slot < experts_per_thread; ++slot) {
        if (softmax_slot_in_range(slot, limit, lane, use_limit)) {
            row_max = max(row_max, vals[slot]);
        }
    }
    row_max = subgroupMax(row_max);

    // Pass 2: shifted exponentials, accumulated into a subgroup-wide sum.
    float row_sum = 0.f;
    [[unroll]]
    for (uint slot = 0; slot < experts_per_thread; ++slot) {
        if (!softmax_slot_in_range(slot, limit, lane, use_limit)) {
            vals[slot] = 0.f;
            continue;
        }
        const float shifted_exp = exp(vals[slot] - row_max);
        vals[slot] = shifted_exp;
        row_sum += shifted_exp;
    }
    row_sum = subgroupAdd(row_sum);

    // Pass 3: normalize the participating slots by the reciprocal of the sum.
    const float scale = 1.0f / row_sum;
    [[unroll]]
    for (uint slot = 0; slot < experts_per_thread; ++slot) {
        if (softmax_slot_in_range(slot, limit, lane, use_limit)) {
            vals[slot] *= scale;
        }
    }
}

// Fused MoE routing for one row per subgroup:
//   1. load the row's logits into per-lane registers,
//   2. apply the gating function (softmax / sigmoid / none for GATING_FUNC_SOFTMAX_WEIGHT),
//   3. optionally add a per-expert bias to the scores used for ranking,
//   4. pick the top n_expert_used experts by repeated subgroup argmax,
//   5. optionally normalize (with_norm) and/or softmax (GATING_FUNC_SOFTMAX_WEIGHT) the
//      selected weights, then write ids and output_scale * w + output_bias.
void main() {
    // Row handled by this subgroup. This assumes each subgroup is exactly one y-slice of
    // WARP_SIZE invocations (subgroup size == WARP_SIZE) - presumably enforced by the host.
    // `row` depends only on workgroup and subgroup ids, so this early return is
    // subgroup-uniform and the subgroup operations below stay well-defined.
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint bias_offset = 0; // 1D
    const uint weights_offset = n_expert_used * row;
    // ids use a row stride of n_experts (not n_expert_used); presumably the ids buffer is
    // laid out with n_experts entries per row on the host side - TODO confirm.
    const uint ids_offset = n_experts * row;
    const uint lane = gl_SubgroupInvocationID;

    // probs[s] on lane L holds expert L + s * WARP_SIZE. Slots beyond n_experts stay -inf.
    float probs[experts_per_thread];
    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        probs[i] = -INFINITY;
    }

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + lane;
        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
    }

    // Gating. For GATING_FUNC_SOFTMAX_WEIGHT the raw logits are kept here and the softmax
    // is applied after top-k, over the selected weights only (see the end of main).
    if (gating_func == GATING_FUNC_SOFTMAX) {
        // With the spec-constant count, padding slots already hold -inf and contribute
        // exp(-inf) = 0, so the index limit is only needed for the push-constant count.
        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
    } else if (gating_func == GATING_FUNC_SIGMOID) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
        }
    }

    // Scores used only for ranking experts. With has_bias they are probs + bias[expert];
    // the values reported as weights always come from `probs`.
    float selection_probs[experts_per_thread];
    if (has_bias != 0) {
        [[unroll]]
        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
        }
    } else {
        [[unroll]]
        for (int i = 0; i < experts_per_thread; i++) {
            selection_probs[i] = probs[i];
        }
    }

    // At this point each lane holds its slice of the gated scores (`probs`) and of the
    // ranking scores (`selection_probs`). We run n_expert_used rounds of subgroup argmax
    // over selection_probs, marking each winner -inf so later rounds skip it.

    // Sum of the selected weights, accumulated only by the lane that owns each winner.
    float wt_sum = 0.f;

    // Selected weight k is stored on lane k % WARP_SIZE in slot k / WARP_SIZE.
    float output_weights[experts_per_thread];

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        output_weights[i] = 0.f;
    }

    for (int k = 0; k < n_expert_used; k++) {
        // Per-lane candidate: the best of this lane's own slots. Strict '>' keeps the
        // earliest (smallest-index) expert on ties.
        float max_val    = probs[0];
        float max_val_s  = selection_probs[0];
        uint   max_expert = lane;

        [[unroll]]
        for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
            const uint expert = i + lane;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
                max_val    = probs[i / WARP_SIZE];
                max_val_s  = selection_probs[i / WARP_SIZE];
                max_expert = expert;
            }
        }

        // Butterfly (xor-shuffle) reduction across the subgroup. Afterwards every lane holds
        // the same winner: highest ranking score, ties broken toward the smaller expert
        // index. The halving masks assume WARP_SIZE is a power of two.
        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val    = subgroupShuffleXor(max_val, mask);
            const float val_s  = subgroupShuffleXor(max_val_s, mask);
            const uint  expert = subgroupShuffleXor(max_expert, mask);
            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                max_val    = val;
                max_val_s  = val_s;
                max_expert = expert;
            }
        }

        // Lane k % WARP_SIZE records the k-th selected weight (the gated prob, not the biased score).
        if ((k & (WARP_SIZE - 1)) == lane) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        // The lane that owns the winning expert excludes it from later rounds, writes its id,
        // and adds its weight to the running sum.
        if ((max_expert & (WARP_SIZE - 1)) == lane) {
            selection_probs[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            wt_sum += max_val;
        }
    }

    // Optional normalization: divide every selected weight by the subgroup-wide sum,
    // clamped to [clamp_min, clamp_max]. with_norm is a push constant, so this branch is
    // uniform and subgroupAdd is executed by the whole subgroup.
    if (with_norm != 0) {
        wt_sum              = subgroupAdd(wt_sum);
        wt_sum              = clamp(wt_sum, clamp_min, clamp_max);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    // Delayed softmax over only the n_expert_used selected weights.
    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
    }

    // Write the k selected weights for this row, applying the output affine transform.
    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + lane;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
        }
    }
}