1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4#extension GL_EXT_shader_16bit_storage : enable
  5#extension GL_KHR_shader_subgroup_arithmetic : enable
  6
// Specialization constants supplied when the pipeline is created.
//   BLOCK_SIZE    - invocations per workgroup (also drives local_size_x below).
//   NUM_SUBGROUPS - number of slots in the per-subgroup shared reduction arrays.
//                   NOTE(review): presumably BLOCK_SIZE / subgroup size; main()
//                   indexes the arrays with gl_SubgroupID, so the host must keep
//                   these consistent -- confirm against pipeline setup.
//   Br, Bc        - tile height (rows, dim 1) and tile width (elements along dim 0).
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
layout (constant_id = 2) const uint Br = 32;
layout (constant_id = 3) const uint Bc = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// The input mask is bound twice at binding 0: as scalar f16 (used by the general
// path) and as f16vec4 (used by the vectorized path, which indexes it with the
// scalar element offset divided by 4).
layout (binding = 0) readonly buffer A {float16_t data_a[];};
layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
// Output: one 32-bit word of packed 2-bit tile codes per workgroup.
layout (binding = 1) writeonly buffer D {uint data_d[];};

layout (push_constant) uniform parameter {
    uint nem0;  // mask extent along dim 0 (contiguous within a row); j0 is checked against it
    uint nem1;  // mask extent along dim 1; j1 is checked against it
    uint nem2;  // extent of dim 2; gl_WorkGroupID.z is split into (i2, i3) with it
    uint nbm1;  // mask strides for dims 1..3, in f16 elements (used to index data_a)
    uint nbm2;
    uint nbm3;
    uint nbd1;  // output strides for dims 1..3, in 32-bit words (used to index data_d)
    uint nbd2;
    uint nbd3;
};

// 2-bit per-tile codes packed into the output word (0 means "neither").
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2

// Per-subgroup partial min/max, combined by invocation 0 after a barrier.
shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];
 35
 36// For each Br x Bc block of the mask (input) buffer, read all values and check
 37// if it's all -inf or all zero. Write out a two-bit code indicating which it is
 38// (or zero for neither). Each workgroup processes 16 tiles and writes out a
 39// 32-bit result mask.
 40//
 41// TODO: This is a lot of work per workgroup, might make sense to split this into
 42// more workgroups in the future.
 43void main() {
 44    // Each workgroup handles a row
 45    const uint tid = gl_LocalInvocationIndex;
 46    const uint i0 = gl_WorkGroupID.x;
 47    const uint i1 = gl_WorkGroupID.y;
 48    const uint i2 = gl_WorkGroupID.z % nem2;
 49    const uint i3 = gl_WorkGroupID.z / nem2;
 50
 51    float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF);
 52
 53    uint result = 0;
 54
 55    // Fast path for fully in-bounds blocks where we can do f16vec4 loads
 56    if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
 57        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) {
 58        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
 59            float min_v = FLT_MAX_OVER_2;
 60            float max_v = -FLT_MAX_OVER_2;
 61            [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
 62                uint j0 = (i + tid) % (Bc / 4);
 63                uint j1 = (i + tid) / (Bc / 4);
 64
 65                j0 *= 4;
 66                j0 += (i0 * 16 + block_x) * Bc;
 67                j1 += i1 * Br;
 68
 69                vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]);
 70                [[unroll]] for (int c = 0; c < 4; ++c) {
 71                    min_v = min(min_v, f[c]);
 72                    max_v = max(max_v, f[c]);
 73                }
 74            }
 75            min_v = subgroupMin(min_v);
 76            max_v = subgroupMax(max_v);
 77            if (gl_SubgroupInvocationID == 0) {
 78                minsh[gl_SubgroupID] = min_v;
 79                maxsh[gl_SubgroupID] = max_v;
 80            }
 81            barrier();
 82            if (tid == 0) {
 83                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
 84                    min_v = min(min_v, minsh[i]);
 85                    max_v = max(max_v, maxsh[i]);
 86                }
 87                if (max_v <= -FLT_MAX_OVER_2) {
 88                    result |= 1 << (2*block_x);
 89                }
 90                if (min_v == 0.0f && max_v == 0.0f) {
 91                    result |= 2 << (2*block_x);
 92                }
 93            }
 94            barrier();
 95        }
 96    } else {
 97        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
 98            float min_v = FLT_MAX_OVER_2;
 99            float max_v = -FLT_MAX_OVER_2;
100            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
101                if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) {
102                    continue;
103                }
104                uint j0 = (i + tid) % Bc;
105                uint j1 = (i + tid) / Bc;
106
107                j0 += (i0 * 16 + block_x) * Bc;
108                j1 += i1 * Br;
109
110                if (j0 < nem0 && j1 < nem1) {
111                    float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]);
112                    min_v = min(min_v, f);
113                    max_v = max(max_v, f);
114                }
115            }
116            min_v = subgroupMin(min_v);
117            max_v = subgroupMax(max_v);
118            if (gl_SubgroupInvocationID == 0) {
119                minsh[gl_SubgroupID] = min_v;
120                maxsh[gl_SubgroupID] = max_v;
121            }
122            barrier();
123            if (tid == 0) {
124                [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) {
125                    min_v = min(min_v, minsh[i]);
126                    max_v = max(max_v, maxsh[i]);
127                }
128                if (max_v <= -FLT_MAX_OVER_2) {
129                    result |= 1 << (2*block_x);
130                }
131                if (min_v == 0.0f && max_v == 0.0f) {
132                    result |= 2 << (2*block_x);
133                }
134            }
135            barrier();
136        }
137    }
138
139    if (tid == 0) {
140        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
141    }
142}