1#version 450
2
3#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
4
5#include "mul_mat_vec_base.glsl"
6#include "dequant_funcs.glsl"
7
// One-dimensional workgroup: the X size is taken from specialization
// constant 0, Y and Z are fixed at 1.
layout(local_size_x_id = 0,
       local_size_y = 1,
       local_size_z = 1) in;

// K_PER_ITER is the number of consecutive K-dimension elements one invocation
// consumes per call to iter(): 2 for the unquantized A formats (F32, F16,
// BF16) and 8 for every other A format.
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
#define K_PER_ITER 2
#else
#define K_PER_ITER 8
#endif


// File-scope offsets shared between compute_outputs() and iter().
// a_offset, b_offset and d_offset are filled in by get_offsets().
uint a_offset;
uint b_offset;
uint d_offset;
// Distance (in B elements) between the two B values that pair with one
// dequantized A pair: 1 when QUANT_R == 1, otherwise QUANT_K/2.
uint y_offset;
18
// Accumulate one K-slice of the matrix-vector product into `temp`.
//
//   temp      - per-invocation accumulators, indexed [column of B][row of A]
//   first_row - first A row handled by this workgroup
//   num_rows  - number of rows to process (at most NUM_ROWS)
//   tid       - local invocation index along X
//   i         - iteration index already scaled by K_PER_ITER by the caller;
//               the starting column is i*BLOCK_SIZE + K_PER_ITER*tid
//   lastiter  - only consulted in the K_PER_ITER == 2 path, where it enables
//               the bounds check on the second element of the pair
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
    // None of these depend on the B column, so compute them once up front.
    const uint col  = i*BLOCK_SIZE + K_PER_ITER*tid;
    const uint iqs  = (col%QUANT_K)/QUANT_R; // quant index inside the block
    const uint iybs = col - col%QUANT_K;     // first B index of the block

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        // Index of the first B element this invocation reads for column j.
        const uint b_idx = j*p.batch_stride_b + b_offset + iybs + iqs;

#if K_PER_ITER == 8
#if QUANT_R == 2
        // The two halves of the block sit y_offset apart in B; load four
        // values from each half and interleave them so they line up with the
        // order of the values returned by dequantize4().
        const vec4 b_first  = vec4(data_b_v4[b_idx / 4]);
        const vec4 b_second = vec4(data_b_v4[(b_idx + y_offset) / 4]);
        const vec4 bv0 = vec4(b_first.x, b_second.x, b_first.y, b_second.y);
        const vec4 bv1 = vec4(b_first.z, b_second.z, b_first.w, b_second.w);
#else
        // Eight consecutive B values as two adjacent vec4 loads.
        const vec4 bv0 = vec4(data_b_v4[b_idx / 4]);
        const vec4 bv1 = vec4(data_b_v4[b_idx / 4 + 1]);
#endif
#else
        // On the last iteration the second element of the pair may lie past
        // the end of the K dimension; in that case B is neither fetched nor
        // accumulated for it. A is still fetched as a pair, which is fine for
        // quantized formats since both values are within the same block. The
        // second element should probably be skipped for F16/F32 as well, but
        // as of now it is still fetched.
        const bool second_in_bounds = !lastiter || (iybs + iqs + y_offset < p.ncols);

        FLOAT_TYPE b0 = FLOAT_TYPE(data_b[b_idx]);
        FLOAT_TYPE b1 = FLOAT_TYPE(0);
        if (second_in_bounds) {
            b1 = FLOAT_TYPE(data_b[b_idx + y_offset]);
        }
#endif

        // Running element offset of the current A row.
        uint row_start = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint ib = (row_start + col)/QUANT_K; // block index in A

#if K_PER_ITER == 8
            vec4 a_lo = dequantize4(ib, iqs, a_offset);
            vec4 a_hi = dequantize4(ib, iqs+(4/QUANT_R), a_offset);

            const vec2 dm = get_dm(ib, a_offset);
            const bool has_min = (dm.y != 0); // quant has min component
            if (has_min) {
                a_lo = a_lo * dm.x + dm.y;
                a_hi = a_hi * dm.x + dm.y;
            }

            // matrix multiplication
            FLOAT_TYPE partial = dot(bv0, a_lo);
            partial += dot(bv1, a_hi);

            // Without a min component the scale is applied once to the sum
            // instead of to every dequantized value.
            if (!has_min)
                partial *= dm.x;

            temp[j][n] += partial;
#else
            const vec2 a_pair = dequantize(ib, iqs, a_offset);

            // matrix multiplication
            temp[j][n] = fma(FLOAT_TYPE(a_pair.x), b0, temp[j][n]);
            if (second_in_bounds) {
                temp[j][n] = fma(FLOAT_TYPE(a_pair.y), b1, temp[j][n]);
            }
#endif

            row_start += p.ncols;
        }
    }
}
85
// Returns how many of the first `num_iters` iterations can be executed in
// groups of `unroll_count` (expected to be a power of two) with
// lastiter == false.
uint mmv_unrolled_iters(const uint num_iters, const uint unroll_count) {
    uint unrolled = num_iters & ~(unroll_count - 1);

#if K_PER_ITER == 2
    // If the K dimension is odd, the final iteration must run with
    // lastiter == true so OOB is computed correctly. When the unrolled groups
    // would otherwise cover every iteration, give up one group.
    if ((p.ncols & 1) != 0 &&
        unrolled == num_iters &&
        unrolled > 0) {
        unrolled -= unroll_count;
    }
#endif

    return unrolled;
}

// Computes `num_rows` output rows starting at `first_row` for every B column:
// zeroes the accumulators, walks the K dimension with iter() in groups of
// four, then two, then one at a time, and finally hands the accumulators to
// reduce_result().
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    get_offsets(a_offset, b_offset, d_offset);

    y_offset = (QUANT_R == 1) ? 1 : QUANT_K/2;

    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];

    [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
        [[unroll]] for (uint r = 0; r < NUM_ROWS; ++r) {
            temp[c][r] = FLOAT_TYPE(0);
        }
    }

    // Full passes over the workgroup, plus one more for the invocations whose
    // starting column in the partial pass is still inside the K dimension.
    const uint cols_per_iter = K_PER_ITER * BLOCK_SIZE;
    uint num_iters = p.ncols / cols_per_iter;
    if (num_iters * cols_per_iter + K_PER_ITER*tid < p.ncols) {
        ++num_iters;
    }

    const uint end_by_four = mmv_unrolled_iters(num_iters, 4u);
    const uint end_by_two  = mmv_unrolled_iters(num_iters, 2u);

    uint it = 0;

    // Manually partially unroll the loop: groups of four ...
    while (it < end_by_four) {
        [[unroll]] for (uint k = 0; k < 4u; ++k) {
            iter(temp, first_row, num_rows, tid, it*K_PER_ITER, false);
            it++;
        }
    }

    // ... then groups of two ...
    while (it < end_by_two) {
        [[unroll]] for (uint k = 0; k < 2u; ++k) {
            iter(temp, first_row, num_rows, tid, it*K_PER_ITER, false);
            it++;
        }
    }

    // ... and whatever is left runs singly with lastiter == true.
    while (it < num_iters) {
        iter(temp, first_row, num_rows, tid, it*K_PER_ITER, true);
        it++;
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
152
// Entry point. Each workgroup handles a tile of up to NUM_ROWS rows whose
// first row is derived from the workgroup's X and Z indices.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    // p.stride_d is used here as the exclusive upper bound on row indices.
    const uint row_limit = p.stride_d;

    // Not enough rows left for a full tile: either there is nothing to do at
    // all, or only the remaining rows are processed.
    if (first_row + NUM_ROWS > row_limit) {
        if (first_row >= row_limit) {
            return;
        }
        compute_outputs(first_row, row_limit - first_row);
        return;
    }

    // Common case: do NUM_ROWS at a time.
    compute_outputs(first_row, NUM_ROWS);
}