1#version 450
 2
 3#include "types.glsl"
 4#include "generic_binary_head.glsl"
 5
 6layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 7
 8void main() {
 9    const uint i00 = gl_GlobalInvocationID.x;
10
11    if (i00 >= p.ne00) {
12        return;
13    }
14
15    uint gid_z = gl_GlobalInvocationID.z;
16    while (gid_z < p.ne11 * p.ne12) {
17        uint gid_y = gl_GlobalInvocationID.y;
18        while (gid_y < p.ne10) {
19            const uint i10 = gid_y;
20            const uint i11 = gid_z / p.ne12;
21            const uint i12 = gid_z % p.ne12;
22
23            const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
24
25            const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
26            const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
27
28#if defined(DATA_A_BF16)
29            TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
30#else
31            TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
32#endif
33#ifndef OPTIMIZATION_ERROR_WORKAROUND
34            data_d[d_offset + i00] = D_TYPE(v);
35#else
36            data_d[d_offset + i00] = D_TYPE(v);
37#endif
38            gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
39        }
40        gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
41    }
42}