1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
  2#pragma OPENCL EXTENSION cl_khr_subgroups : enable
  3
/*
 * Offsets, in floats, of the four 256-byte quarters of a sub-group's
 * 256-float staging area in matrix_B_local (64 floats * 4 bytes = 256 bytes).
 * mm_mad adds these to each work-item's local offset when it scatters the
 * components of regB into local memory.
 */
#define LM_FIRST_256B   0
#define LM_SECOND_256B  64
#define LM_THIRD_256B   128
#define LM_FOURTH_256B  192
  8
  9
 10inline float16 mm_load_a(
 11    image1d_buffer_t matrix_A,
 12    uint subMatrixAStartInElements,
 13    int nb01,
 14    int line_stride_matrix_A_in_bytes
 15) {
 16    __private float8 regA;
 17    size_t sub_block_id_m = get_local_id(0);
 18
 19#ifdef KQV
 20    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * nb01/4);
 21#else // KQ
 22    uint a_texCoord = subMatrixAStartInElements/2 + (sub_block_id_m * line_stride_matrix_A_in_bytes/4);
 23#endif
 24
 25    regA.s0123  = read_imagef(matrix_A, a_texCoord/4);
 26    regA.s4567  = read_imagef(matrix_A, (a_texCoord+4)/4);
 27
 28    return convert_float16(as_half16(regA));
 29}
 30
 31inline float4 alu_32(
 32    float16 regA,
 33    __local float4* matrix_B_vec
 34) {
 35
 36    __private float4 rC = 0;
 37    int i = get_sub_group_id() * 64;
 38
 39    rC += regA.s0  * matrix_B_vec[i];
 40    rC += regA.s1  * matrix_B_vec[i + 16];
 41    rC += regA.s4  * matrix_B_vec[i + 1];
 42    rC += regA.s5  * matrix_B_vec[i + 17];
 43    rC += regA.s8  * matrix_B_vec[i + 2];
 44    rC += regA.s9  * matrix_B_vec[i + 18];
 45    rC += regA.sc  * matrix_B_vec[i + 3];
 46    rC += regA.sd  * matrix_B_vec[i + 19];
 47
 48    i += 32;
 49
 50    rC += regA.s2  * matrix_B_vec[i];
 51     rC += regA.s3  * matrix_B_vec[i + 16];
 52    rC += regA.s6  * matrix_B_vec[i + 1];
 53    rC += regA.s7  * matrix_B_vec[i + 17];
 54    rC += regA.sa  * matrix_B_vec[i + 2];
 55    rC += regA.sb  * matrix_B_vec[i + 18];
 56    rC += regA.se  * matrix_B_vec[i + 3];
 57    rC += regA.sf  * matrix_B_vec[i + 19];
 58
 59    return rC;
 60}
 61
 62inline float16 alu_16(
 63    float16 regA,
 64    __local float* matrix_B_local
 65) {
 66    float16 out;
 67    __local float4* matrix_B_vec = (__local float4*)matrix_B_local;
 68
 69    out.s0123 = alu_32(regA, matrix_B_vec);
 70    out.s4567 = alu_32(regA, matrix_B_vec + 4);
 71    out.s89ab = alu_32(regA, matrix_B_vec + 8);
 72    out.scdef = alu_32(regA, matrix_B_vec + 12);
 73
 74    return out;
 75}
 76
 77inline void mm_mad(
 78    __local float* matrix_B_local,
 79    float16 regA,
 80    float8 regB,
 81    uint b_localOffsetInWords,
 82    float16* regC0_ptr,
 83    float16* regC1_ptr
 84) {
 85    int offset = b_localOffsetInWords + get_sub_group_id() * 256;
 86
 87    matrix_B_local[offset + LM_FIRST_256B] = regB.s0;
 88    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
 89    matrix_B_local[offset + LM_THIRD_256B] = regB.s2;
 90    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
 91
 92    float16 add0 = alu_16(regA, matrix_B_local);
 93    *regC0_ptr += add0;
 94
 95    matrix_B_local[offset + LM_FIRST_256B] = regB.s4;
 96    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
 97    matrix_B_local[offset + LM_THIRD_256B] = regB.s6;
 98    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
 99
100    float16 add1 = alu_16(regA, matrix_B_local);
101    *regC1_ptr += add1;
102}
103
/*
 * Write up to 32 accumulated values of this work-item's column of C.
 *
 * Row r (0..31) takes regC0 lane r for r < 16 and regC1 lane r-16 otherwise.
 * It is written to word
 *   (subMatrixCStartInElements + get_local_id(0)) + r * (line_stride_matrix_C_in_bytes / 4).
 * Only rows with r < mask are written, which lets the caller clip the final
 * partial tile. A mask <= 0 writes nothing.
 *
 * NOTE(review): no bound is applied along the column (local id) direction;
 * presumably the caller guarantees that range is in bounds — TODO confirm.
 */
inline void mm_store_c_N(
    __write_only image1d_buffer_t matrix_C,
    float16 regC0,
    float16 regC1,
    uint subMatrixCStartInElements,
    int line_stride_matrix_C_in_bytes,
    int mask
) {
    // Lay both accumulator halves out contiguously so rows can be indexed.
    float rowValues[32];
    vstore16(regC0, 0, rowValues);
    vstore16(regC1, 1, rowValues);

    const uint rowStride = line_stride_matrix_C_in_bytes/4;
    const uint firstRow  = subMatrixCStartInElements + get_local_id(0);

    // The per-row guard (mask > r) is monotonic in r, so stopping early is
    // equivalent to testing every row individually.
    for (int r = 0; r < 32 && r < mask; ++r) {
        write_imagef(matrix_C, firstRow + (uint)r * rowStride, rowValues[r]);
    }
}
182
// Tile sizes: K elements consumed per loop step, and the M x N tile of C
// produced by one work-group (each work-item writes TILESIZE_N values).
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32
/*
 * C = A * B with f16 A (read through an image) and f32 B (global buffer);
 * C is written as f32 through an image.
 *
 * Work decomposition:
 *   get_global_id(1)         -> tile index along M (TILESIZE_M columns of C)
 *   get_global_id(2)         -> (N tile, batch depth), N tile varying fastest
 *   get_local_id(0)          -> column of C within the M tile
 * The batch index of A is block_id_d / (D_B / D_A), i.e. A is shared across
 * groups of D_B / D_A batches of B.
 *
 * Build variants: compiled as mul_mm_f16_f32_kqv when KQV is defined,
 * otherwise as mul_mm_f16_f32_kq; the two differ in the A/B stride formulas.
 *
 * NOTE(review): offset0, offset1 and offsetd are accepted but never applied
 * to the buffers — presumably the host passes 0 or pre-offset buffers; confirm.
 * NOTE(review): there is no bound check along M; presumably M is a multiple
 * of TILESIZE_M and the kernel runs with a local size of TILESIZE_M along
 * dimension 0 (one sub-group) — confirm on the host side.
 * NOTE(review): D_B / D_A must be non-zero (D_B >= D_A), or the depth_A
 * computation divides by zero.
 */
#ifdef KQV
__kernel void mul_mm_f16_f32_kqv(
#else
__kernel void mul_mm_f16_f32_kq(
#endif
        __read_only  image1d_buffer_t matrix_A,
        int offset0,
        __global float* matrix_B,
        int offset1,
        __write_only image1d_buffer_t matrix_C,
        int offsetd,
        int M, int K, int N,
        int D_A,
        int D_B,
        int nb01
) {
    // Number of TILESIZE_N-wide tiles needed to cover N (ceil division).
    const int num_blocks_n = (N + TILESIZE_N - 1) / TILESIZE_N;

    const uint block_id_m = get_global_id(1);
    const uint block_id_n = get_global_id(2) % num_blocks_n;
    const uint block_id_d = get_global_id(2) / num_blocks_n;

    __private float16 regA;
    __private float8  regB;
    // mm_mad only accumulates (+=) into these, and private variables are not
    // implicitly initialized, so they must start at zero.
    __private float16 regC0 = (float16)(0.0f);
    __private float16 regC1 = (float16)(0.0f);

    const uint col     = block_id_m * TILESIZE_M;
    const uint row     = block_id_n * TILESIZE_N;
    const uint depth_A = block_id_d / (D_B/D_A);
    const uint depth_B = block_id_d;

#ifdef KQV
    int line_stride_matrix_A_in_bytes = nb01 * M;
    int line_stride_matrix_B_in_bytes = K * N * 4;
#else
    int line_stride_matrix_A_in_bytes = K * D_A * 2;
    int line_stride_matrix_B_in_bytes = K * D_B * 4;
#endif

    int line_stride_matrix_C_in_bytes = M * 4;

    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;

    size_t sub_block_id_m = get_local_id(0);

    // Permuted slot of this work-item inside its sub-group's local staging
    // region: bits 0,1 of the local id land on bits 2,3 and bits 2,3 on bits
    // 0,1, while the remaining bits keep their position.
    uint b_localOffsetInWords = (sub_block_id_m/16)*16
                           + ((((sub_block_id_m)>>0)&1)<<2)
                           + ((((sub_block_id_m)>>1)&1)<<3)
                           + ((((sub_block_id_m)>>2)&1)<<0)
                           + ((((sub_block_id_m)>>3)&1)<<1);

    // Global B position of this work-item's float4 loads: x selects one of
    // four 4-float chunks along K, y selects the row along N.
    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
    uint b_globalOffsetInWords00, b_globalOffsetInWords16;
#ifdef KQV
    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
#else
    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
#endif

    __local float matrix_B_local[1024];

    for (uint step = 0; step < K; step += TILESIZE_K) {
        regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);

        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;

        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);

        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);

        subMatrixAStartInElements += TILESIZE_K;
        subMatrixBStartInElements += TILESIZE_K;
    }

    // Clip the last N tile: only the first (N - row) rows of this tile exist.
    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N - block_id_n * TILESIZE_N));
}
273