1#version 450
2
3#include "types.glsl"
4#include "generic_binary_head.glsl"
5
6layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
7
// Gather kernel: for every (i10, i11, i12) index position, read a row index
// i01 from data_b and copy element i00 of row i01 of data_a (in the i11/i12
// batch slice) into data_d.
//
// Invocation mapping:
//   x -> i00, the element within the row (one element per invocation; no loop)
//   y -> i10, looped with a stride of the total y-dimension invocation count
//   z -> flattened (i11, i12) batch index, looped likewise over ne11*ne12
//
// NOTE(review): p.nb* strides are added directly to array indices, so they are
// presumably element (not byte) strides -- confirm in generic_binary_head.
void main() {
    const uint i00 = gl_GlobalInvocationID.x;

    if (i00 >= p.ne00) {
        return;
    }

    // Loop-invariant bounds and strides (total invocations along y and z).
    const uint n_batches = p.ne11 * p.ne12;
    const uint stride_y  = gl_WorkGroupSize.y * gl_NumWorkGroups.y;
    const uint stride_z  = gl_WorkGroupSize.z * gl_NumWorkGroups.z;

    for (uint gid_z = gl_GlobalInvocationID.z; gid_z < n_batches; gid_z += stride_z) {
        // Split the flattened batch index; done once per batch instead of once
        // per inner iteration. Kept inside the loop so ne12 == 0 never divides.
        const uint i11 = gid_z / p.ne12;
        const uint i12 = gid_z % p.ne12;

        // Batch-dependent parts of the three offsets (uint addition wraps, so
        // regrouping the sums yields bit-identical indices).
        const uint b_batch = i11*p.nb11 + i12*p.nb12;
        const uint a_batch = i11*p.nb02 + i12*p.nb03;
        const uint d_batch = i11*p.nb22 + i12*p.nb23;

        for (uint i10 = gl_GlobalInvocationID.y; i10 < p.ne10; i10 += stride_y) {
            // Row of data_a to gather, taken from the index tensor.
            const uint i01 = data_b[get_boffset() + i10*p.nb10 + b_batch];

            const uint a_offset = get_aoffset() + i01*p.nb01 + a_batch;
            const uint d_offset = get_doffset() + i10*p.nb21 + d_batch;

#if defined(DATA_A_BF16)
            TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else
            TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
#endif
            // The former OPTIMIZATION_ERROR_WORKAROUND branch emitted exactly
            // this same store, so the conditional was collapsed; defining that
            // macro has no effect on this shader.
            data_d[d_offset + i00] = D_TYPE(v);
        }
    }
}