1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
|
#version 450
// Provides the [[unroll]] loop attribute used in main().
#extension GL_EXT_control_flow_attributes : enable
// Provides float16_t / f16vec4 as storage-buffer member types (see bindings below).
#extension GL_EXT_shader_16bit_storage : enable
// Provides subgroupMin() / subgroupMax() used for the per-block reduction.
#extension GL_KHR_shader_subgroup_arithmetic : enable
// Specialization constants; the values given here are only the defaults.
// BLOCK_SIZE: number of invocations per workgroup (it is also local_size_x,
// via local_size_x_id = 0 below).
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
// NUM_SUBGROUPS: length of the shared reduction arrays. main() reads exactly
// this many entries, so it presumably must equal the real number of subgroups
// in a workgroup -- TODO confirm against the host-side pipeline setup.
layout (constant_id = 1) const uint NUM_SUBGROUPS = 4;
// Br x Bc: block size in mask elements. Br counts rows (bounded by nem1),
// Bc counts columns (bounded by nem0).
layout (constant_id = 2) const uint Br = 32;
layout (constant_id = 3) const uint Bc = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Binding 0 is declared twice: the same mask buffer is viewed both as scalar
// halfs (bounds-checked path) and as groups of four halfs (fast path).
layout (binding = 0) readonly buffer A {float16_t data_a[];};
layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];};
// Output: one 32-bit word per workgroup, holding 16 two-bit block codes.
layout (binding = 1) writeonly buffer D {uint data_d[];};
// Push constants. Member order defines the layout the host must match.
layout (push_constant) uniform parameter {
// Mask extents: nem0 = columns, nem1 = rows, nem2 = size of dim 2
// (gl_WorkGroupID.z is split into dim 2 / dim 3 indices using nem2).
uint nem0;
uint nem1;
uint nem2;
// Mask strides for dims 1..3, in float16_t elements (used directly as
// indices into data_a).
uint nbm1;
uint nbm2;
uint nbm3;
// Output strides for dims 1..3, in uint elements of data_d.
uint nbd1;
uint nbd2;
uint nbd3;
};
// Two-bit classification codes for a block: every value is -inf / every
// value is zero. A code of 0 means neither.
#define MASK_OPT_ALL_NEG_INF 1
#define MASK_OPT_ALL_ZERO 2
// Per-subgroup partial minimum / maximum, combined by invocation 0.
shared float minsh[NUM_SUBGROUPS];
shared float maxsh[NUM_SUBGROUPS];
// For each Br x Bc block of the mask (input) buffer, read all values and check
// if it's all -inf or all zero. Write out a two-bit code indicating which it is
// (MASK_OPT_ALL_NEG_INF, MASK_OPT_ALL_ZERO, or zero for neither). Each
// workgroup processes MASK_OPT_BLOCKS_PER_WG (16) horizontally adjacent blocks
// of one block row and writes out a single 32-bit result mask, with block b
// stored in bits [2*b, 2*b+1].
//
// Elements outside nem0 x nem1 are never read. A block with no in-range
// elements keeps its initial min/max and is therefore reported as
// MASK_OPT_ALL_NEG_INF.
//
// TODO: This is a lot of work per workgroup, might make sense to split this into
// more workgroups in the future.

// Number of blocks classified per workgroup: 16 blocks * 2 bits = one uint.
const uint MASK_OPT_BLOCKS_PER_WG = 16;

// Reduces every invocation's (min_v, max_v) over the whole workgroup and
// classifies the block.
//
// min_v / max_v     this invocation's partial minimum / maximum
// flt_max_over_2    FLT_MAX / 2; a maximum at or below its negation counts as -inf
// returns           the two-bit code on invocation 0, and 0 on every other
//                   invocation
//
// Contains barriers, so all invocations of the workgroup must call it from
// uniform control flow. The trailing barrier keeps the next call from
// overwriting minsh/maxsh while invocation 0 is still reading them.
uint classify_block(float min_v, float max_v, const float flt_max_over_2) {
    min_v = subgroupMin(min_v);
    max_v = subgroupMax(max_v);
    if (gl_SubgroupInvocationID == 0) {
        minsh[gl_SubgroupID] = min_v;
        maxsh[gl_SubgroupID] = max_v;
    }
    barrier();

    uint code = 0u;
    if (gl_LocalInvocationIndex == 0) {
        [[unroll]] for (uint s = 0; s < NUM_SUBGROUPS; ++s) {
            min_v = min(min_v, minsh[s]);
            max_v = max(max_v, maxsh[s]);
        }
        if (max_v <= -flt_max_over_2) {
            code |= uint(MASK_OPT_ALL_NEG_INF);
        }
        if (min_v == 0.0f && max_v == 0.0f) {
            code |= uint(MASK_OPT_ALL_ZERO);
        }
    }
    barrier();
    return code;
}

void main() {
    // gl_WorkGroupID.x selects a group of 16 blocks along dim 0, .y selects the
    // block row, and .z encodes the dim 2 / dim 3 indices.
    const uint tid = gl_LocalInvocationIndex;
    const uint i0 = gl_WorkGroupID.x;
    const uint i1 = gl_WorkGroupID.y;
    const uint i2 = gl_WorkGroupID.z % nem2;
    const uint i3 = gl_WorkGroupID.z / nem2;

    // 0x7EFFFFFF is the bit pattern of FLT_MAX / 2. It is used both as the
    // initial min/max sentinel and as the "is -inf" threshold.
    const float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFFu);

    const uint row0 = i1 * Br;                      // first row of this block row
    const uint plane_offset = i2 * nbm2 + i3 * nbm3; // element offset of the 2D slice

    // The f16vec4 path divides the element offset by 4, so every term of that
    // offset has to be a multiple of 4; otherwise the truncation would read
    // the wrong elements. Anything that does not qualify uses the scalar path.
    const bool use_vec4 =
        (nem0 % Bc) == 0 && (nem1 % Br) == 0 &&
        ((Br * Bc) % (BLOCK_SIZE * 4)) == 0 &&
        (Bc % 4) == 0 &&
        (nbm1 % 4) == 0 && (nbm2 % 4) == 0 && (nbm3 % 4) == 0;

    uint result = 0u;

    if (use_vec4) {
        // Fast path: blocks are either fully inside or fully outside the mask,
        // so one workgroup-uniform test per block replaces per-element checks.
        [[unroll]] for (uint block_x = 0; block_x < MASK_OPT_BLOCKS_PER_WG; ++block_x) {
            const uint col0 = (i0 * MASK_OPT_BLOCKS_PER_WG + block_x) * Bc;
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;

            // Skip blocks past the mask extents (e.g. when the number of
            // blocks per row is not a multiple of 16) instead of reading out
            // of bounds.
            if (col0 < nem0 && row0 < nem1) {
                [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) {
                    const uint idx = i + tid;
                    const uint j0 = col0 + 4 * (idx % (Bc / 4));
                    const uint j1 = row0 + idx / (Bc / 4);

                    const vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + plane_offset) / 4]);
                    [[unroll]] for (int c = 0; c < 4; ++c) {
                        min_v = min(min_v, f[c]);
                        max_v = max(max_v, f[c]);
                    }
                }
            }

            result |= classify_block(min_v, max_v, FLT_MAX_OVER_2) << (2u * block_x);
        }
    } else {
        // Bounds-checked scalar path.
        [[unroll]] for (uint block_x = 0; block_x < MASK_OPT_BLOCKS_PER_WG; ++block_x) {
            const uint col0 = (i0 * MASK_OPT_BLOCKS_PER_WG + block_x) * Bc;
            float min_v = FLT_MAX_OVER_2;
            float max_v = -FLT_MAX_OVER_2;

            [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) {
                const uint idx = i + tid;
                // Only needed when the block size is not a multiple of the
                // workgroup size: the last pass then has idle invocations.
                if ((Br * Bc % BLOCK_SIZE) != 0 && idx >= Br * Bc) {
                    continue;
                }
                const uint j0 = col0 + idx % Bc;
                const uint j1 = row0 + idx / Bc;
                if (j0 < nem0 && j1 < nem1) {
                    const float f = float(data_a[j0 + j1 * nbm1 + plane_offset]);
                    min_v = min(min_v, f);
                    max_v = max(max_v, f);
                }
            }

            result |= classify_block(min_v, max_v, FLT_MAX_OVER_2) << (2u * block_x);
        }
    }

    // Only invocation 0 holds the assembled codes.
    if (tid == 0) {
        data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result;
    }
}
|