1#version 450
2
3#extension GL_EXT_control_flow_attributes : require
4
5#include "types.glsl"
6
// Specialization constant with id 0; its value is 32 unless the host overrides
// it when the pipeline is created.
7layout(constant_id = 0) const uint BLOCK_SIZE = 32;
8
// Workgroup size: the x extent is taken from specialization constant id 0
// (the same id BLOCK_SIZE is declared with); y and z are fixed at 1.
9layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
10
// Storage buffers, all viewed as flat float arrays.
// src0 and src1 are read-only inputs; dst is the buffer main() writes to.
11layout(binding = 0) readonly buffer Src0 { float src0[]; };
12layout(binding = 1) readonly buffer Src1 { float src1[]; };
13layout(binding = 2) buffer Dst { float dst[]; };
14
// Per-dispatch parameters supplied by the host as push constants.
// NOTE(review): main() divides every *nb* value it uses by 4 before indexing
// the float arrays, which suggests these are byte strides for 4-byte floats —
// confirm against the host code.
15layout(push_constant) uniform PushConstants {
 // src0 strides: nb01 is applied to the global x index, nb02 to the workgroup z index.
16 uint nb01; uint nb02;
 // src1 stride applied to the global x index.
17 uint nb11;
 // dst strides: dst_nb1 is applied to the workgroup y index, dst_nb2 to the
 // workgroup z index. dst_nb0 is not referenced by main() in this file.
18 uint dst_nb0; uint dst_nb1; uint dst_nb2;
 // nc is the trip count of the accumulation loop; nr, n_t and n_s are the
 // exclusive upper bounds checked against the x, y and z indices respectively.
 // ncs is not referenced by main() in this file.
19 uint nc; uint ncs; uint nr; uint n_t; uint n_s;
20};
21
// Shader entry point: each invocation produces one element of dst, computed as
// the sum of nc element-wise products between a run of consecutive floats in
// src0 and a run of consecutive floats in src1.
//
// Index mapping:
//   gl_GlobalInvocationID.x -> row    (must be < nr)
//   gl_WorkGroupID.y        -> t_idx  (must be < n_t)
//   gl_WorkGroupID.z        -> s_idx  (must be < n_s)
void main() {
    const uint row   = gl_GlobalInvocationID.x;
    const uint t_idx = gl_WorkGroupID.y;
    const uint s_idx = gl_WorkGroupID.z;

    // Invocations that fall outside the requested extent do no work.
    const bool in_range = (row < nr) && (t_idx < n_t) && (s_idx < n_s);
    if (!in_range) {
        return;
    }

    // Strides expressed in float elements. The push constants are divided by 4
    // before use (presumably byte strides for 4-byte floats — confirm with the
    // host code).
    const uint src0_row_stride = nb01 / 4;
    const uint src0_z_stride   = nb02 / 4;
    const uint src1_row_stride = nb11 / 4;
    const uint dst_y_stride    = dst_nb1 / 4;
    const uint dst_z_stride    = dst_nb2 / 4;

    // First element read from each input. For src0 the y index is added
    // directly, i.e. it moves the start of the run by one float per step.
    const uint src0_off = s_idx * src0_z_stride + t_idx + row * src0_row_stride;
    const uint src1_off = row * src1_row_stride;

    // Output slot; the row index is used as a unit-stride offset.
    const uint out_idx = s_idx * dst_z_stride + t_idx * dst_y_stride + row;

    // Accumulate in ascending order so the floating-point result is unchanged.
    float acc = 0.0;
    [[unroll]] for (uint k = 0; k < nc; k++) {
        acc += src0[src0_off + k] * src1[src1_off + k];
    }

    dst[out_idx] = acc;
}