1#version 450
  2
  3#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
  4#extension GL_EXT_integer_dot_product : require
  5
  6#define MMQ
  7#define B_TYPE block_q8_1_x4
  8
  9#include "mul_mat_vec_base.glsl"
 10
 11layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 12
// K_PER_ITER is the per-invocation column step for the selected A format:
// iter() places invocation `tid` at column offset tid*K_PER_ITER, and
// cache_b_qs is sized K_PER_ITER / 4 words. The first matching DATA_A_* family
// below picks the value; if none is defined, compilation stops with an error.
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
#define K_PER_ITER 32
#else
#error unimplemented
#endif
 22
// Offsets filled in by get_offsets() at the start of compute_outputs(), which
// then multiplies a_offset by QUANT_K / QUANT_K_Q8_1 and divides b_offset by
// QUANT_K_Q8_1; d_offset is forwarded unchanged to reduce_result().
uint a_offset, b_offset, d_offset;

// Staging for the current data_b sub-block, rewritten by iter() for every B
// column: K_PER_ITER / 4 words copied from .qs and the vec2 built from .ds.
// NOTE(review): presumably read by mmvq_dot_product() in
// mul_mat_vecq_funcs.glsl — confirm against that file.
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;
 27
 28#include "mul_mat_vecq_funcs.glsl"
 29
 30void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
 31    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
 32        const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
 33
 34        // Preload data_b block
 35        const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
 36        const uint b_qs_idx = tid % (32 / K_PER_ITER);
 37        const uint b_block_idx_outer = b_block_idx / 4;
 38        const uint b_block_idx_inner = b_block_idx % 4;
 39        cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
 40
 41#if QUANT_R == 2
 42        // Assumes K_PER_ITER == 8
 43        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
 44        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
 45#else
 46#if K_PER_ITER == 8
 47        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
 48        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
 49#elif K_PER_ITER == 16
 50        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4    ];
 51        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
 52        cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
 53        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
 54#elif K_PER_ITER == 32
 55        cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8    ];
 56        cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
 57        cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
 58        cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
 59        cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
 60        cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
 61        cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
 62        cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
 63#else
 64#error unimplemented
 65#endif
 66#endif
 67
 68        uint ibi = first_row*p.ncols;
 69        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 70            const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
 71            ibi += p.ncols;
 72
 73            temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
 74        }
 75    }
 76}
 77
 78void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 79    const uint tid = gl_LocalInvocationID.x;
 80
 81    get_offsets(a_offset, b_offset, d_offset);
 82    a_offset *= QUANT_K / QUANT_K_Q8_1;
 83    b_offset /= QUANT_K_Q8_1;
 84
 85    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 86
 87    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
 88        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 89            temp[j][n] = FLOAT_TYPE(0.0f);
 90        }
 91    }
 92
 93    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
 94    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
 95        num_iters++;
 96    }
 97    int unroll_count = 4;
 98    uint unrolled_iters = num_iters & ~(unroll_count - 1);
 99
100    uint i = 0;
101    while (i < unrolled_iters) {
102        // Manually partially unroll the loop
103        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
104            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
105            i++;
106        }
107    }
108
109    unroll_count = 2;
110    unrolled_iters = num_iters & ~(unroll_count - 1);
111
112    while (i < unrolled_iters) {
113        // Manually partially unroll the loop
114        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
115            iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
116            i++;
117        }
118    }
119    while (i < num_iters) {
120        iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
121        i++;
122    }
123
124    reduce_result(temp, d_offset, first_row, num_rows, tid);
125}
126
// Shader entry point: derives the first row for this workgroup from its X/Z
// workgroup IDs and processes up to NUM_ROWS rows, using p.stride_d as the
// row limit.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#ifdef NEEDS_INIT_IQ_SHMEM
    // Runs before any early exit below.
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // Common case: a full tile of NUM_ROWS rows fits below the limit.
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Partial tile: handle only the rows that remain, if there are any.
    if (first_row < p.stride_d) {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}