1#version 450
 2
 3#extension GL_EXT_control_flow_attributes : require
 4
 5#include "types.glsl"
 6
 7layout(constant_id = 0) const uint BLOCK_SIZE = 32;
 8
 9layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
10
11layout(binding = 0) readonly buffer Src0 { float src0[]; };
12layout(binding = 1) readonly buffer Src1 { float src1[]; };
13layout(binding = 2) buffer Dst { float dst[]; };
14
// Per-dispatch parameters. The nb* / dst_nb* values are divided by 4 before
// being used as float-array indices in main() -- presumably they are byte
// strides and 4 == sizeof(float); confirm against the host code.
// Member order and types determine the block layout; do not reorder.
layout(push_constant) uniform PushConstants {
    // src0 strides: nb01 is scaled by the x index, nb02 by the z index.
    uint nb01;
    uint nb02;

    // src1 stride, scaled by the x index.
    uint nb11;

    // dst strides: dst_nb1 is scaled by the y index, dst_nb2 by the z index.
    // dst_nb0 is not referenced by the shader code visible in this file.
    uint dst_nb0;
    uint dst_nb1;
    uint dst_nb2;

    // nc:  number of multiply-accumulate terms per output element.
    // ncs: not referenced by the shader code visible in this file.
    uint nc;
    uint ncs;

    // Upper bounds (exclusive) used for the range check in main():
    // nr for gl_GlobalInvocationID.x, n_t for gl_WorkGroupID.y,
    // n_s for gl_WorkGroupID.z.
    uint nr;
    uint n_t;
    uint n_s;
};
21
// Entry point. Each in-range invocation writes exactly one float to dst:
// the sum over k in [0, nc) of src0[a + k] * src1[b + k], where the start
// offsets a and b are derived from the invocation's x/y/z coordinates and the
// push-constant strides.
void main() {
    const uint row = gl_GlobalInvocationID.x;  // bounded by nr
    const uint t   = gl_WorkGroupID.y;         // bounded by n_t
    const uint s   = gl_WorkGroupID.z;         // bounded by n_s

    // Invocations outside the problem extents do nothing.
    const bool in_range = (row < nr) && (t < n_t) && (s < n_s);
    if (!in_range) {
        return;
    }

    // Strides are divided by 4 to obtain float-element offsets
    // (presumably byte strides with 4 == sizeof(float) -- confirm with host).
    const uint src0_row_stride = nb01 / 4;
    const uint src0_s_stride   = nb02 / 4;
    const uint src1_row_stride = nb11 / 4;
    const uint dst_t_stride    = dst_nb1 / 4;
    const uint dst_s_stride    = dst_nb2 / 4;

    // In src0 the y index advances by a single element, so adjacent y values
    // read windows that overlap by nc - 1 elements.
    const uint src0_start = s * src0_s_stride + t + row * src0_row_stride;
    const uint src1_start = row * src1_row_stride;
    const uint out_idx    = s * dst_s_stride + t * dst_t_stride + row;

    // Accumulate in ascending k so the floating-point summation order is fixed.
    float acc = 0.0;
    [[unroll]] for (uint k = 0; k < nc; ++k) {
        acc += src0[src0_start + k] * src1[src1_start + k];
    }

    dst[out_idx] = acc;
}