#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_subgroups : enable

// On Qualcomm Adreno, pin the kernel to the 64-wide ("half") subgroup size.
#ifdef cl_qcom_reqd_sub_group_size
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#endif

#define QK8_0 32       // int8 weights per Q8_0 quantization block
#define N_SIMDGROUP 4  // subgroups (waves) per workgroup

 13#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \
 14    float shared_y; \
 15    char elem; \
 16                                             \
 17    shared_y = sub_group_broadcast(y.s0, 0); \
 18    elem = (char)(bits8.s0 & 0x000000FF); \
 19    total_sums += convert_int(elem) * scale * shared_y; \
 20    shared_y = sub_group_broadcast(y.s1, 0); \
 21    elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \
 22    total_sums += convert_int(elem) * scale * shared_y; \
 23    shared_y = sub_group_broadcast(y.s2, 0); \
 24    elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \
 25    total_sums += convert_int(elem) * scale * shared_y; \
 26    shared_y = sub_group_broadcast(y.s3, 0); \
 27    elem = (char)((bits8.s0 & 0xFF000000) >> 24); \
 28    total_sums += convert_int(elem) * scale * shared_y; \
 29                                             \
 30    shared_y = sub_group_broadcast(y.s4, 0); \
 31    elem = (char)(bits8.s1 & 0x000000FF); \
 32    total_sums += convert_int(elem) * scale * shared_y; \
 33    shared_y = sub_group_broadcast(y.s5, 0); \
 34    elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \
 35    total_sums += convert_int(elem) * scale * shared_y; \
 36    shared_y = sub_group_broadcast(y.s6, 0); \
 37    elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \
 38    total_sums += convert_int(elem) * scale * shared_y; \
 39    shared_y = sub_group_broadcast(y.s7, 0); \
 40    elem = (char)((bits8.s1 & 0xFF000000) >> 24); \
 41    total_sums += convert_int(elem) * scale * shared_y; \
 42                                             \
 43    shared_y = sub_group_broadcast(y.s0, 1); \
 44    elem = (char)(bits8.s2 & 0x000000FF); \
 45    total_sums += convert_int(elem) * scale * shared_y; \
 46    shared_y = sub_group_broadcast(y.s1, 1); \
 47    elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \
 48    total_sums += convert_int(elem) * scale * shared_y; \
 49    shared_y = sub_group_broadcast(y.s2, 1); \
 50    elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \
 51    total_sums += convert_int(elem) * scale * shared_y; \
 52    shared_y = sub_group_broadcast(y.s3, 1); \
 53    elem = (char)((bits8.s2 & 0xFF000000) >> 24); \
 54    total_sums += convert_int(elem) * scale * shared_y; \
 55                                             \
 56    shared_y = sub_group_broadcast(y.s4, 1); \
 57    elem = (char)(bits8.s3 & 0x000000FF); \
 58    total_sums += convert_int(elem) * scale * shared_y; \
 59    shared_y = sub_group_broadcast(y.s5, 1); \
 60    elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \
 61    total_sums += convert_int(elem) * scale * shared_y; \
 62    shared_y = sub_group_broadcast(y.s6, 1); \
 63    elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \
 64    total_sums += convert_int(elem) * scale * shared_y; \
 65    shared_y = sub_group_broadcast(y.s7, 1); \
 66    elem = (char)((bits8.s3 & 0xFF000000) >> 24); \
 67    total_sums += convert_int(elem) * scale * shared_y; \
 68                                             \
 69    shared_y = sub_group_broadcast(y.s0, 2); \
 70    elem = (char)(bits8.s4 & 0x000000FF); \
 71    total_sums += convert_int(elem) * scale * shared_y; \
 72    shared_y = sub_group_broadcast(y.s1, 2); \
 73    elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \
 74    total_sums += convert_int(elem) * scale * shared_y; \
 75    shared_y = sub_group_broadcast(y.s2, 2); \
 76    elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \
 77    total_sums += convert_int(elem) * scale * shared_y; \
 78    shared_y = sub_group_broadcast(y.s3, 2); \
 79    elem = (char)((bits8.s4 & 0xFF000000) >> 24); \
 80    total_sums += convert_int(elem) * scale * shared_y; \
 81                                             \
 82    shared_y = sub_group_broadcast(y.s4, 2); \
 83    elem = (char)(bits8.s5 & 0x000000FF); \
 84    total_sums += convert_int(elem) * scale * shared_y; \
 85    shared_y = sub_group_broadcast(y.s5, 2); \
 86    elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \
 87    total_sums += convert_int(elem) * scale * shared_y; \
 88    shared_y = sub_group_broadcast(y.s6, 2); \
 89    elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \
 90    total_sums += convert_int(elem) * scale * shared_y; \
 91    shared_y = sub_group_broadcast(y.s7, 2); \
 92    elem = (char)((bits8.s5 & 0xFF000000) >> 24); \
 93    total_sums += convert_int(elem) * scale * shared_y; \
 94                                             \
 95    shared_y = sub_group_broadcast(y.s0, 3); \
 96    elem = (char)(bits8.s6 & 0x000000FF); \
 97    total_sums += convert_int(elem) * scale * shared_y; \
 98    shared_y = sub_group_broadcast(y.s1, 3); \
 99    elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \
100    total_sums += convert_int(elem) * scale * shared_y; \
101    shared_y = sub_group_broadcast(y.s2, 3); \
102    elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \
103    total_sums += convert_int(elem) * scale * shared_y; \
104    shared_y = sub_group_broadcast(y.s3, 3); \
105    elem = (char)((bits8.s6 & 0xFF000000) >> 24); \
106    total_sums += convert_int(elem) * scale * shared_y; \
107                                             \
108    shared_y = sub_group_broadcast(y.s4, 3); \
109    elem = (char)(bits8.s7 & 0x000000FF); \
110    total_sums += convert_int(elem) * scale * shared_y; \
111    shared_y = sub_group_broadcast(y.s5, 3); \
112    elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \
113    total_sums += convert_int(elem) * scale * shared_y; \
114    shared_y = sub_group_broadcast(y.s6, 3); \
115    elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \
116    total_sums += convert_int(elem) * scale * shared_y; \
117    shared_y = sub_group_broadcast(y.s7, 3); \
118    elem = (char)((bits8.s7 & 0xFF000000) >> 24); \
119    total_sums += convert_int(elem) * scale * shared_y; \
120
// GEMV for Q8_0-quantized A times dense B: dst[gid] = sum_k A[gid,k] * B[k].
// Each fiber (work-item) owns one output row; the workgroup's N_SIMDGROUP
// waves split the K dimension and combine partial sums through local memory.
//
// NOTE(review): SIMDGROUP_WIDTH is not defined in this file — presumably
// injected as a -DSIMDGROUP_WIDTH=<wave width> build option; confirm against
// the host-side program build flags.
// NOTE(review): offset1 (offset into B) is accepted but never applied; the
// parameter comment says it is always 0 — verify against the caller.
#ifdef ADRENO_GPU
REQD_SUBGROUP_SIZE_64
#endif
__kernel void kernel_gemv_noshuffle(
        __read_only  image1d_buffer_t src0_q,  // quantized A
        global half  * src0_d,  // A scales
        __read_only  image1d_buffer_t src1,    // B
        ulong offset1,            // offset to B (0)
        global float * dst,     // C
        ulong offsetd,            // offset to C
        int ne00,               // K
        int ne01,               // M
        int ne02,               // 1
        int ne10,               // K
        int ne12,               // 1
        int ne0,                // M
        int ne1,                // N
        int r2,                 // 1
        int r3)
{
    uint groupId = get_local_id(1);        // wave index within the workgroup (0..N_SIMDGROUP-1)
    uint gid     = get_global_id(0);       // output row handled by this fiber
    ushort slid    = get_sub_group_local_id();

    uint K = ne00;
    uint M = ne01;

    // A is stored column-interleaved: consecutive rows' words are adjacent.
    uint LINE_STRIDE_A = M;
    uint BLOCK_STRIDE_A = 8 * M;   // 32 / 4 = 8 uint words per Q8_0 block

    __private uint8     regA;   // one Q8_0 block of weights (8 x 4 packed int8)
    __private half      regS;   // per-block scale for this row
    __private float8    regB;   // 8 B values (only lanes 0..3 load; broadcast later)

    __private float totalSum = (float)(0.0f);

    // loop along K in block granularity, skip 4 blocks every iter
    #pragma unroll 1 /* tell compiler not to unroll */
    for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) {
        regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one rows
        // first 4 fibers in each wave load 8 B values to its private scope
        if (slid < 4) {
            regB.s0123 = read_imagef(src1, (slid * 2 + k * 8));
            regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8));
        }

        // load weights for one block in consecutive rows
        regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x;
        regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x;
        regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x;
        regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x;
        regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x;
        regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x;
        regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x;
        regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x;

        dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB);
    }

    // reduction in local memory, assumes #wave=4
    __local float reduceLM[SIMDGROUP_WIDTH * 3];
    if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum;
    if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum;
    if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid];
    if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid];

    // 1 outputs per fiber in wave 0
    if (groupId == 0) {
        dst = (global float*)((global char*)dst + offsetd);
        dst[gid] = totalSum;
    }
}