1#version 450
  2
  3#extension GL_EXT_control_flow_attributes : enable
  4#ifdef COOPMAT2
  5#extension GL_NV_cooperative_matrix2 : enable
  6#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  7#extension GL_KHR_memory_scope_semantics : enable
  8#endif
  9
 10#ifdef USE_COLLECTIVES
 11#    extension GL_KHR_shader_subgroup_shuffle : enable
 12#endif
 13
 14#include "types.glsl"
 15
 16// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j
 17layout(binding = 0) readonly buffer A {
 18    A_TYPE knl_data[];
 19};  // src0 - kernel:   [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d
 20
 21layout(binding = 1) readonly buffer B {
 22    B_TYPE src_data[];
 23};  // src1 - input:    [W, H, Cin, N] -- channel_first format
 24
 25layout(binding = 2) writeonly buffer D {
 26    D_TYPE dst_data[];
 27};  // dst - result:    [OW, OH, Cout, N]
 28
 29layout(push_constant) uniform parameter {
 30    // I/O channels, batch size
 31    uint32_t Cout;
 32    uint32_t Cin;
 33    uint32_t N;
 34
 35    // Tensor spatial sizes: input, output
 36    uint32_t W;
 37    uint32_t H;
 38    uint32_t OW;
 39    uint32_t OH;
 40
 41    // Strides in elements
 42    uint32_t nb01;
 43    uint32_t nb02;
 44    uint32_t nb03;
 45
 46    uint32_t nb11;
 47    uint32_t nb12;
 48    uint32_t nb13;
 49
 50    uint32_t nb1;
 51    uint32_t nb2;
 52    uint32_t nb3;
 53
 54    // fastdiv helper values
 55    uint32_t OWmp;   uint32_t OWL;
 56    uint32_t OWOHmp; uint32_t OWOHL;
 57}
 58
 59p;
 60
 61layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 62// Blocktile sizes
 63layout(constant_id = 1) const uint BS_K            = 128;
 64layout(constant_id = 2) const uint BS_CRS          = 16;
 65layout(constant_id = 3) const uint BS_NPQ          = 128;
 66// Thread-tile sizes
 67layout(constant_id = 4) const uint TS_K            = 8;
 68layout(constant_id = 5) const uint use_collectives = 1;
 69layout(constant_id = 6) const uint SHMEM_PAD       = 4;
 70// Stride, padding, dilation
 71layout(constant_id = 7)  const uint s0             = 1;
 72layout(constant_id = 8)  const uint s1             = 1;
 73layout(constant_id = 9)  const uint p0             = 0;
 74layout(constant_id = 10) const uint p1             = 0;
 75layout(constant_id = 11) const uint d0             = 1;
 76layout(constant_id = 12) const uint d1             = 1;
 77// Kernel spatial sizes
 78layout(constant_id = 13) const uint KW             = 1;
 79layout(constant_id = 14) const uint KH             = 1;
 80
 81uint32_t       tid     = gl_LocalInvocationID.x;
 82const uint32_t WG_SIZE = gl_WorkGroupSize.x;
 83
 84uint splitWork(uint work_size, uint block_size) {
 85    return (block_size + work_size - 1) / block_size;
 86}
 87
 88uint32_t K   = p.Cout;
 89uint32_t CRS = p.Cin * KH * KW;
 90uint32_t NPQ = p.N * p.OH * p.OW;
 91
 92uint32_t n_elems_out = K * NPQ;
 93
 94// Number of blocktiles per input
 95uint32_t NB_CRS = splitWork(CRS, BS_CRS);
 96
 97#ifdef COOPMAT2
 98#define SHMEM_TYPE float16_t
 99#else
100#define SHMEM_TYPE float
101#endif
102
103const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
104const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
105
106const uint32_t Ash_numel = BS_K * BS_CRS;
107const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
108
109const uint32_t Ash_len = BS_K * Ash_stride;
110const uint32_t Bsh_len = BS_CRS * Bsh_stride;
111
112shared SHMEM_TYPE Ash[Ash_len];  // K x CRS
113shared SHMEM_TYPE Bsh[Bsh_len];  // CRS x NPQ
114
115// Threadtile sizes
116const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
117
118// Number of threadtiles per blocktile
119const uint32_t NT_K   = BS_K / TS_K;
120const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
121
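/*
The convolution is computed as a matrix product:

    A (K x CRS)  @  B (CRS x NPQ)  =  C (K x NPQ)

with the dimension names
    K    = Cout       (output channels)
    C    = Cin        (input channels)
    R, S = KH, KW     (kernel height, kernel width)
    N    = batch size
    P, Q = OH, OW     (output height, output width)
*/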
130
// Blocktile handled by this workgroup. The NPQ blocktile index is spread over the
// y and z dispatch dimensions with 512 blocktiles per z step.
// NOTE(review): 512 presumably mirrors the host-side dispatch split - confirm in ggml-vulkan.cpp.
uint32_t B_idx_K = gl_WorkGroupID.x;
uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512;

// Position of this invocation's thread tile inside the blocktile:
// T_y selects the thread tile along K, T_x the one along NPQ.
uint32_t T_y = tid / NT_NPQ;
uint32_t T_x = tid % NT_NPQ;

// Cooperative load of the A tile: this invocation starts at row Ar, column Ac, and
// the whole workgroup covers ArpWg rows per pass.
uint32_t Ar = tid / BS_CRS;
uint32_t Ac = tid % BS_CRS;
const uint32_t ArpWg = WG_SIZE / BS_CRS;

// Cooperative load of the B tile: this invocation starts at row Br, column Bc, and
// the whole workgroup covers BrpWg rows per pass.
uint32_t Br = tid / BS_NPQ;
uint32_t Bc = tid % BS_NPQ;
const uint32_t BrpWg = WG_SIZE / BS_NPQ;
144
// Division of n by a fixed divisor using a multiply-high, an add and a shift:
//     result = (mulhi(n, mp) + n) >> L
// The (mp, L) pair describes the divisor; see init_fastdiv_values in ggml-vulkan.cpp
// for how it is derived.
uint fastdiv(uint n, uint mp, uint L) {
    uint hi;
    uint lo;
    // hi:lo receives the full 64-bit product n * mp; only the upper word is needed
    umulExtended(n, mp, hi, lo);
    return (n + hi) >> L;
}
152
153#ifdef COOPMAT2
154#define ACC_TYPE float16_t
155
// Per-element callback that writes one accumulator element to dst.
//   r, c  - row (K direction) and column (NPQ direction) of the element inside
//           this workgroup's blocktile
//   elem  - the accumulated value
// The element is stored only when it lies inside the K x NPQ output; the value is
// returned unchanged in every case.
ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
{
    // Global coordinates in the K x NPQ output matrix
    uint32_t out_ch = B_idx_K * BS_K + r;
    uint32_t npq    = B_idx_NPQ * BS_NPQ + c;

    // Split the flat NPQ index into (batch, output row, output column)
    uint32_t batch    = fastdiv(npq, p.OWOHmp, p.OWOHL);      // npq / (OH * OW)
    uint32_t in_image = npq - batch * p.OH * p.OW;
    uint32_t out_row  = fastdiv(in_image, p.OWmp, p.OWL);     // in_image / OW
    uint32_t out_col  = in_image - out_row * p.OW;

    if (out_ch < K && npq < NPQ) {
        dst_data[out_col + out_row * p.nb1 + out_ch * p.nb2 + batch * p.nb3] = D_TYPE(elem);
    }
    return elem;
}
169#endif
170
/*
 * Entry point. Each workgroup produces one BS_K x BS_NPQ blocktile of the K x NPQ
 * output by walking the CRS dimension in steps of BS_CRS:
 *   1. stage a BS_K x BS_CRS tile of the kernel in Ash and a BS_CRS x BS_NPQ tile of
 *      the input in Bsh (elements outside the problem are stored as zero),
 *   2. multiply the two tiles and accumulate (cooperative matrix when COOPMAT2 is
 *      defined, otherwise a TS_K x TS_NPQ tile per invocation),
 *   3. after the last step, write the accumulated tile to dst_data.
 */
void main() {
    // The whole NPQ blocktile lies past the end of the output: nothing to compute.
    // The condition is the same for every invocation of the workgroup, so the
    // barriers below are never reached by only part of it.
    if (B_idx_NPQ * BS_NPQ >= NPQ) {
        return;
    }

#ifdef COOPMAT2
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
    matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
#else
    // Per-invocation accumulator tile
    float regC[TS_K][TS_NPQ];
    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
            regC[T_ly][T_lx] = 0.0;
        }
    }
#endif
    /* Advance block in CRS dim */
    [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
        // Decomposition of this invocation's A-tile column into (Cin, KH, KW)
        uint32_t CRS_idx_a;
        uint32_t Cin_idx_a;
        uint32_t KH_idx_a;
        uint32_t KW_idx_a;

#ifdef USE_COLLECTIVES
        uint32_t cached_CRS_idx;
        uint32_t cached_Cin_idx;
        uint32_t cached_KH_idx;
        uint32_t cached_KW_idx;
        if (use_collectives == 1) {
            // Every subgroup invocation decomposes the CRS index that matches its
            // own invocation id once; other invocations fetch the result with a
            // shuffle instead of repeating the divisions.
            // NOTE(review): this presumably requires at least BS_CRS invocations per
            // subgroup - confirm the host only enables use_collectives in that case.
            cached_CRS_idx                = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
            cached_Cin_idx                = cached_CRS_idx / (KW * KH);
            uint32_t cached_CRS_remainder = cached_CRS_idx % (KW * KH);
            cached_KH_idx                 = cached_CRS_remainder / KW;
            cached_KW_idx                 = cached_CRS_remainder % KW;

            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
            KH_idx_a  = subgroupShuffle(cached_KH_idx, Ac);
            KW_idx_a  = subgroupShuffle(cached_KW_idx, Ac);
        } else {
            CRS_idx_a              = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
            Cin_idx_a              = CRS_idx_a / (KW * KH);
            uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
            KH_idx_a               = CRS_remainder / KW;
            KW_idx_a               = CRS_remainder % KW;
        }
#else
        CRS_idx_a              = B_idx_CRS * BS_CRS + Ac;  // Global CRS_idx_a (column index of A)
        Cin_idx_a              = CRS_idx_a / (KW * KH);
        // Declared here: this branch previously assigned CRS_remainder without a
        // declaration, which fails to compile when USE_COLLECTIVES is not defined.
        uint32_t CRS_remainder = CRS_idx_a % (KW * KH);
        KH_idx_a               = CRS_remainder / KW;
        KW_idx_a               = CRS_remainder % KW;
#endif

        /* Load kernel to A_block: (BS_K x BS_CRS)*/
        UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
            uint32_t B_ly    = r_offset + Ar;
            uint32_t B_lx    = Ac;
            uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
            // The index is clamped so the load itself is always issued with an index
            // below K * CRS; elements outside the matrix are replaced by zero afterwards.
#ifdef TRANSPOSE
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
#else
            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
#endif
            float    val     = knl_data[knl_idx];
            if (K_idx >= K || CRS_idx_a >= CRS) {
                val = 0.0;
            }
            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
        }
        /* Load input to B_block: (BS_CRS x BS_NPQ) */
        UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
            uint32_t B_ly          = r_offset + Br;             /* Row index of B block */
            uint32_t B_lx          = Bc;
            uint32_t NPQ_idx       = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
            // Split the flat NPQ index into (batch, output row, output column)
            uint32_t N_idx         = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
            uint32_t OH_idx        = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW;
            uint32_t OW_idx        = NPQ_remainder - OH_idx * p.OW;

            // Decomposition of this B-tile row into (Cin, KH, KW)
            uint32_t CRS_idx_b;
            uint32_t Cin_idx_b;
            uint32_t KH_idx_b;
            uint32_t KW_idx_b;
#ifdef USE_COLLECTIVES
            if (use_collectives == 1) {
                // Reuse the decomposition cached by the subgroup invocation whose id equals the row
                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
                KH_idx_b  = subgroupShuffle(cached_KH_idx, r_offset + Br);
                KW_idx_b  = subgroupShuffle(cached_KW_idx, r_offset + Br);
            } else {
                CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
                Cin_idx_b              = CRS_idx_b / (KW * KH);
                uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
                KH_idx_b               = CRS_remainder / KW;
                KW_idx_b               = CRS_remainder % KW;
            }
#else
            CRS_idx_b              = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
            Cin_idx_b              = CRS_idx_b / (KW * KH);
            uint32_t CRS_remainder = CRS_idx_b % (KW * KH);
            KH_idx_b               = CRS_remainder / KW;
            KW_idx_b               = CRS_remainder % KW;
#endif

#ifdef TRANSPOSE
            // Transposed convolution: the input position is the stride-scaled position
            // divided by the stride; positions that are not a multiple of the stride
            // are rejected below.
            uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * d1 + p1;
            uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * d0 + p0;
            uint32_t H_idx = H_idx_x_s1 / s1;
            uint32_t W_idx = W_idx_x_s0 / s0;
#else
            uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1;
            uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0;
#endif
            // Clamp the index so the load is always issued with an index below
            // Cin * N * W * H; out-of-range elements are replaced by zero afterwards.
            uint32_t src_idx =
                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
            float val = src_data[src_idx];
            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
                || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case)
#ifdef TRANSPOSE
                || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0)
#endif
                ) {
                val = 0.0;
            }
            Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
        }
        // Both shared tiles must be completely written before any invocation reads them
        barrier();
#ifdef COOPMAT2
        coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
        coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;

        coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
        coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
        matC = coopMatMulAdd(matA, matB, matC);
#else
        if (T_y * TS_K < K) {
            UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
                // Column CRS_lidx of the A tile and row CRS_lidx of the B tile, restricted
                // to this invocation's thread tile
                float regA[TS_K];
                float regB[TS_NPQ];
                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                    regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
                }
                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                    regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
                }
                // Rank-1 update of the accumulator tile
                for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                    for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                        regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
                    }
                }
            }
        }
#endif
        // All reads of the shared tiles must finish before the next step overwrites them
        barrier();
    }
    /* Save C* */
#ifdef COOPMAT2
    coopMatPerElementNV(matC, matC, perElemOpStore);
#else
    if (T_y * TS_K < K) {
        for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
                uint32_t K_idx   = B_idx_K * BS_K + T_y * TS_K + T_ly;
                uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
                uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
                uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
                uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
                uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
                // Thread tiles may overhang the edge of the output; skip those elements
                if (K_idx < K && NPQ_idx < NPQ) {
                    dst_data[dst_idx] = regC[T_ly][T_lx];
                }
            }
        }
    }
#endif
}