1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : require
  4#extension GL_EXT_shader_16bit_storage : require
  5
  6#ifdef USE_SUBGROUPS
  7#extension GL_KHR_shader_subgroup_basic : require
  8#extension GL_KHR_shader_subgroup_clustered : require
  9
 10#define INVOCATION_ID gl_SubgroupInvocationID.x
 11#else
 12#define INVOCATION_ID gl_LocalInvocationID.x
 13#endif
 14
 15layout (push_constant) uniform parameter
 16{
 17    uint ne;
 18    uint num_blocks;
 19} p;
 20
 21#include "types.glsl"
 22
 23layout(constant_id = 0) const uint GROUP_SIZE = 32;
 24layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 25
 26layout (binding = 0) readonly buffer A {vec4 data_a[];};
 27#ifndef QBLOCK_X4
 28layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
 29#else
 30layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
 31#endif
 32
 33#ifndef USE_SUBGROUPS
 34shared float shmem[GROUP_SIZE];
 35#endif
 36
 37void quantize(const uint wgid) {
 38    const uint tid = INVOCATION_ID;
 39
 40    // Each thread handles a vec4, so 8 threads handle a block
 41    const uint blocks_per_group = GROUP_SIZE / 8;
 42
 43    const uint block_in_wg = tid / 8;
 44
 45    const uint ib = wgid * blocks_per_group + block_in_wg;
 46    const uint iqs = tid % 8;
 47
 48#ifdef QBLOCK_X4
 49    const uint ibx4_outer = ib / 4;
 50    const uint ibx4_inner = ib % 4;
 51
 52    const uint required_x4_blocks = (p.ne + 127) / 128;
 53    if (ibx4_outer >= required_x4_blocks) {
 54        return;
 55    }
 56#endif
 57
 58    const uint a_idx = ib * 8 + iqs;
 59
 60    vec4 vals = a_idx < p.ne / 4 ? data_a[a_idx] : vec4(0.0f);
 61    const vec4 abs_vals = abs(vals);
 62
 63    // Find absolute max for each block
 64    const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
 65#ifndef USE_SUBGROUPS
 66    shmem[tid] = thread_max;
 67    barrier();
 68    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
 69        if (iqs < s) {
 70            shmem[tid] = max(shmem[tid], shmem[tid + s]);
 71        }
 72        barrier();
 73    }
 74
 75    const float amax = shmem[block_in_wg * 8];
 76#else
 77    const float amax = subgroupClusteredMax(thread_max, 8);
 78#endif
 79
 80    const float d = amax / 127.0;
 81    const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
 82    vals = round(vals * d_inv);
 83
 84#ifndef QBLOCK_X4
 85    data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
 86#else
 87    data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
 88#endif
 89
 90#ifndef USE_SUBGROUPS
 91    barrier();
 92#endif
 93
 94    // Calculate the sum for each block
 95    const float thread_sum = vals.x + vals.y + vals.z + vals.w;
 96#ifndef USE_SUBGROUPS
 97    shmem[tid] = thread_sum;
 98    barrier();
 99    [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
100        if (iqs < s) {
101            shmem[tid] += shmem[tid + s];
102        }
103        barrier();
104    }
105#else
106    const float sum = subgroupClusteredAdd(thread_sum, 8);
107#endif
108    if (iqs == 0) {
109#ifndef USE_SUBGROUPS
110        const float sum = shmem[tid];
111#endif
112
113#ifndef QBLOCK_X4
114        data_b[ib].ds = f16vec2(vec2(d, sum * d));
115#else
116        data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
117#endif
118    }
119}
120
void main() {
    // Workgroup-stride loop: each workgroup starts at its own ID and
    // strides by the total dispatch width, so all p.num_blocks groups
    // are covered even when fewer workgroups were launched.
    for (uint wgid = gl_WorkGroupID.x; wgid < p.num_blocks; wgid += gl_NumWorkGroups.x) {
        quantize(wgid);
    }
}