#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require

// With subgroup support, 8-wide clustered reductions replace the
// shared-memory reduction and the invocation index comes from the subgroup.
#ifdef USE_SUBGROUPS
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_clustered : require

#define INVOCATION_ID gl_SubgroupInvocationID.x
#else
#define INVOCATION_ID gl_LocalInvocationID.x
#endif

// ne:         number of source floats (data_a is viewed as ne/4 vec4s)
// num_blocks: grid-stride loop bound for main() — total workgroup iterations
layout (push_constant) uniform parameter
{
    uint ne;
    uint num_blocks;
} p;

#include "types.glsl"

// Workgroup size is a specialization constant (constant_id 0); must be a
// multiple of 8 since 8 threads cooperate on one block.
layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Input: raw floats as vec4s. Output: q8_1 blocks, either one per struct or
// packed four-at-a-time (QBLOCK_X4) — block types come from types.glsl.
layout (binding = 0) readonly buffer A {vec4 data_a[];};
#ifndef QBLOCK_X4
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
#else
layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
#endif

// Scratch for the workgroup-wide max/sum reductions (non-subgroup fallback).
#ifndef USE_SUBGROUPS
shared float shmem[GROUP_SIZE];
#endif
36
// Quantize one workgroup's worth of q8_1 blocks for workgroup-iteration wgid.
// Each thread handles one vec4 (4 floats), so 8 threads cooperate per
// 32-element block: reduce abs-max -> scale d -> store int8 quants -> reduce
// sum -> store (d, sum*d) as f16vec2.
void quantize(const uint wgid) {
    const uint tid = INVOCATION_ID;

    // Each thread handles a vec4, so 8 threads handle a block
    const uint blocks_per_group = GROUP_SIZE / 8;

    const uint block_in_wg = tid / 8;

    const uint ib = wgid * blocks_per_group + block_in_wg;
    const uint iqs = tid % 8;

#ifdef QBLOCK_X4
    const uint ibx4_outer = ib / 4;
    const uint ibx4_inner = ib % 4;

    // One block_q8_1_x4 covers 4 blocks = 128 source floats.
    // NOTE: do NOT early-return on this bound. In the shared-memory fallback
    // the barrier()s below must be reached by every invocation of the
    // workgroup, and this condition can diverge between blocks of the same
    // workgroup (e.g. GROUP_SIZE > 32 spans more than one x4 block), which
    // would be undefined behavior. Out-of-range lanes still join the
    // reductions (their inputs are zero); only their stores are skipped.
    const uint required_x4_blocks = (p.ne + 127) / 128;
    const bool in_range = ibx4_outer < required_x4_blocks;
#endif

    const uint a_idx = ib * 8 + iqs;

    // Out-of-bounds lanes load zeros — neutral for both max and sum.
    vec4 vals = a_idx < p.ne / 4 ? data_a[a_idx] : vec4(0.0f);
    const vec4 abs_vals = abs(vals);

    // Find absolute max for each block
    const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
#ifndef USE_SUBGROUPS
    shmem[tid] = thread_max;
    barrier();
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] = max(shmem[tid], shmem[tid + s]);
        }
        barrier();
    }

    const float amax = shmem[block_in_wg * 8];
#else
    const float amax = subgroupClusteredMax(thread_max, 8);
#endif

    // d maps [-amax, amax] onto the int8 range; guard the all-zero block.
    const float d = amax / 127.0;
    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
    // vals is exactly integral after this round; no second round is needed
    // before the int8 conversion below.
    vals = round(vals * d_inv);

#ifndef QBLOCK_X4
    data_b[ib].qs[iqs] = pack32(i8vec4(vals));
#else
    if (in_range) {
        data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(vals));
    }
#endif

#ifndef USE_SUBGROUPS
    // shmem is reused for the sum reduction: all reads of amax must finish.
    barrier();
#endif

    // Calculate the sum for each block
    const float thread_sum = vals.x + vals.y + vals.z + vals.w;
#ifndef USE_SUBGROUPS
    shmem[tid] = thread_sum;
    barrier();
    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
        if (iqs < s) {
            shmem[tid] += shmem[tid + s];
        }
        barrier();
    }
#else
    const float sum = subgroupClusteredAdd(thread_sum, 8);
#endif
    if (iqs == 0) {
#ifndef USE_SUBGROUPS
        const float sum = shmem[tid];
#endif

#ifndef QBLOCK_X4
        data_b[ib].ds = f16vec2(vec2(d, sum * d));
#else
        if (in_range) {
            data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
        }
#endif
    }
}
120
void main() {
    // Grid-stride loop: each workgroup processes every num_blocks-th chunk,
    // so any dispatch size covers the whole input.
    for (uint wgid = gl_WorkGroupID.x; wgid < p.num_blocks; wgid += gl_NumWorkGroups.x) {
        quantize(wgid);
    }
}