1#version 450
2
3#extension GL_EXT_control_flow_attributes : enable
4#extension GL_EXT_shader_16bit_storage : enable
5#extension GL_KHR_shader_subgroup_arithmetic : enable
6
// Specialization constants.
// BLOCK_SIZE: number of invocations per workgroup (bound to local_size_x below).
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
// NUM_SUBGROUPS: length of the minsh/maxsh scratch arrays. Presumably must equal
// the number of subgroups in a BLOCK_SIZE workgroup, since one entry per
// gl_SubgroupID is written and all entries are read -- confirm on the host side.
layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
// Br x Bc: height (mask rows) and width (mask columns) of each classified tile.
layout (constant_id = 2) const uint Br = 32;
layout (constant_id = 3) const uint Bc = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// The mask input, bound twice at binding 0: as scalar float16 values and as
// f16vec4 values for the vectorized fast path.
layout (binding = 0) readonly buffer A {float16_t data_a[];};
layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
// Output: one 32-bit word of packed 2-bit tile codes per workgroup.
layout (binding = 1) writeonly buffer D {uint data_d[];};

layout (push_constant) uniform parameter {
    // Mask extents: nem0 bounds the column index, nem1 bounds the row index, and
    // nem2 is used to split gl_WorkGroupID.z into the dim-2 / dim-3 indices.
    uint nem0;
    uint nem1;
    uint nem2;
    // Mask strides for dims 1..3, in float16 elements (they index data_a).
    uint nbm1;
    uint nbm2;
    uint nbm3;
    // Output strides for dims 1..3, in uint elements (they index data_d).
    uint nbd1;
    uint nbd2;
    uint nbd3;
};

// Per-tile classification codes (2 bits per tile in the output word).
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2

// Per-subgroup partial min/max values, combined afterwards by invocation 0.
shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];
35
// Combine each invocation's partial (min, max) for one tile across the whole
// workgroup and classify the tile.
//
// Contains barriers, so it must be called from workgroup-uniform control flow.
// Returns the tile's 2-bit code, already shifted to bits [2*block_x, 2*block_x+1],
// on invocation 0. Every other invocation returns 0.
uint classify_tile(float min_v, float max_v, uint block_x, float flt_max_over_2) {
    min_v = subgroupMin(min_v);
    max_v = subgroupMax(max_v);
    if (gl_SubgroupInvocationID == 0) {
        minsh[gl_SubgroupID] = min_v;
        maxsh[gl_SubgroupID] = max_v;
    }
    barrier();

    uint bits = 0u;
    if (gl_LocalInvocationIndex == 0) {
        [[unroll]] for (uint s = 0; s < NUM_SUBGROUPS; ++s) {
            min_v = min(min_v, minsh[s]);
            max_v = max(max_v, maxsh[s]);
        }
        // max still at (or below) the -FLT_MAX/2 sentinel: every value read was
        // -inf, or the tile had no in-bounds values at all.
        if (max_v <= -flt_max_over_2) {
            bits |= uint(MASK_OPT_ALL_NEG_INF);
        }
        if (min_v == 0.0f && max_v == 0.0f) {
            bits |= uint(MASK_OPT_ALL_ZERO);
        }
        bits <<= 2u * block_x;
    }
    // Prevent the next tile's shared-memory writes from racing with the reads above.
    barrier();
    return bits;
}

// For each Br x Bc tile of the mask (input) buffer, read all values and check
// whether the tile is entirely -inf or entirely zero. Write a two-bit code per
// tile: MASK_OPT_ALL_NEG_INF, MASK_OPT_ALL_ZERO, or 0 for neither. Tile k of a
// workgroup occupies bits [2k, 2k+1] of the result word.
//
// Each workgroup processes 16 consecutive tiles along one row of tiles and
// writes a single 32-bit result word. The dispatch layout is:
//   gl_WorkGroupID.x = group of 16 column tiles
//   gl_WorkGroupID.y = tile row
//   gl_WorkGroupID.z = i2 + i3 * nem2
// Elements outside nem0 x nem1 contribute nothing. A tile with no in-bounds
// elements is therefore reported as all -inf.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.
void main() {
    const uint tid = gl_LocalInvocationIndex;
    const uint i0 = gl_WorkGroupID.x;
    const uint i1 = gl_WorkGroupID.y;
    const uint i2 = gl_WorkGroupID.z % nem2;
    const uint i3 = gl_WorkGroupID.z / nem2;

    // Largest finite float / 2. Used as the initial min/max sentinel.
    const float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFFu);

    // Element offset of this workgroup's (i2, i3) mask plane.
    const uint plane_off = i2 * nbm2 + i3 * nbm3;

    uint result = 0u;

    // The f16vec4 fast path requires:
    //  - tiles fully in bounds in both dimensions;
    //  - whole vec4s per tile row (Bc % 4);
    //  - an exact number of loop iterations;
    //  - 4-element-aligned offsets, so that the "/ 4" index into data_av4 is exact.
    const bool use_vec4 =
        (Bc % 4) == 0 &&
        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0 &&
        (nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
        (nbm1 % 4) == 0 && (nbm2 % 4) == 0 && (nbm3 % 4) == 0;

    if (use_vec4) {
        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
            const uint col0 = (i0 * 16 + block_x) * Bc;
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;
            // When the number of column tiles is not a multiple of 16, the last
            // workgroup's trailing tiles lie past the end of the row. Skip their
            // loads instead of reading out of bounds. They keep the sentinels and
            // classify as all -inf, matching the scalar path. The condition is
            // uniform across the workgroup.
            if (col0 < nem0) {
                [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
                    const uint idx = i + tid;
                    const uint j0 = col0 + (idx % (Bc / 4)) * 4;
                    const uint j1 = i1 * Br + idx / (Bc / 4);

                    const vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + plane_off) / 4]);
                    [[unroll]] for (int c = 0; c < 4; ++c) {
                        min_v = min(min_v, f[c]);
                        max_v = max(max_v, f[c]);
                    }
                }
            }
            result |= classify_tile(min_v, max_v, block_x, FLT_MAX_OVER_2);
        }
    } else {
        [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) {
            const uint col0 = (i0 * 16 + block_x) * Bc;
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;
            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
                const uint idx = i + tid;
                // The final iteration overshoots the tile when BLOCK_SIZE does
                // not divide Br * Bc.
                if ((Br * Bc % BLOCK_SIZE) != 0 && idx >= Br * Bc) {
                    continue;
                }
                const uint j0 = col0 + idx % Bc;
                const uint j1 = i1 * Br + idx / Bc;

                if (j0 < nem0 && j1 < nem1) {
                    const float f = float(data_a[j0 + j1 * nbm1 + plane_off]);
                    min_v = min(min_v, f);
                    max_v = max(max_v, f);
                }
            }
            result |= classify_tile(min_v, max_v, block_x, FLT_MAX_OVER_2);
        }
    }

    // Only invocation 0 holds the accumulated codes.
    if (tid == 0) {
        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
    }
}