1#pragma OPENCL EXTENSION cl_khr_fp16 : enable
2#pragma OPENCL EXTENSION cl_khr_subgroups : enable
3
/*
 * matrix_B_local is indexed in floats. Each sub-group works in its own
 * 256-float (1 KiB) slab of it (base get_sub_group_id() * 256). Each
 * work-item scatters four B values into that slab at these offsets,
 * which are 64 floats (256 bytes) apart.
 */
#define LM_QUARTER_IN_FLOATS 64
#define LM_FIRST_256B  (0 * LM_QUARTER_IN_FLOATS)
#define LM_SECOND_256B (1 * LM_QUARTER_IN_FLOATS)
#define LM_THIRD_256B  (2 * LM_QUARTER_IN_FLOATS)
#define LM_FOURTH_256B (3 * LM_QUARTER_IN_FLOATS)
8
9
/*
 * Fetch 16 consecutive half-precision elements of A for this work-item
 * and widen them to float.
 *
 * The row is selected by get_local_id(0). The row pitch in bytes is nb01
 * when KQV is defined and line_stride_matrix_A_in_bytes otherwise; the
 * unselected parameter is ignored.
 *
 * subMatrixAStartInElements is a position in half elements. It is halved
 * to get 32-bit words, and every 4 words form one float4 texel. The two
 * texel reads therefore cover 32 bytes, which are reinterpreted as
 * half16.
 *
 * NOTE(review): assumes matrix_A exposes raw data through a 4x32-bit texel
 * format so the bit reinterpretation is valid. Confirm on the host side.
 */
inline float16 mm_load_a(
    image1d_buffer_t matrix_A,
    uint subMatrixAStartInElements,
    int nb01,
    int line_stride_matrix_A_in_bytes
) {
    const size_t row_in_tile = get_local_id(0);

#ifdef KQV
    const int row_pitch_bytes = nb01;
#else // KQ
    const int row_pitch_bytes = line_stride_matrix_A_in_bytes;
#endif

    /* Start of this row's slice, measured in 32-bit words (2 halves each). */
    const uint word_pos = subMatrixAStartInElements / 2 + (row_in_tile * row_pitch_bytes / 4);

    float8 packed;
    packed.lo = read_imagef(matrix_A, word_pos / 4);        /* halves 0..7  */
    packed.hi = read_imagef(matrix_A, (word_pos + 4) / 4);  /* halves 8..15 */

    return convert_float16(as_half16(packed));
}
30
/*
 * Accumulate 16 scalar-times-float4 products from local memory.
 *
 * The slab starts at matrix_B_vec + get_sub_group_id() * 64 float4s. The
 * caller may pre-offset matrix_B_vec. The products read slab float4s
 * 0..3 and 16..19 (lo) and 32..35 and 48..51 (hi).
 *
 * Pairing with regA, viewing regA as quarters a0..a3 = s0123..scdef:
 *   - .x/.y of each quarter pair with the lo half.
 *   - .z/.w of each quarter pair with the hi half.
 *
 * The accumulation order is fixed; changing it would change the rounding
 * of the result.
 */
inline float4 alu_32(
    float16 regA,
    __local float4* matrix_B_vec
) {
    __local const float4* lo = matrix_B_vec + get_sub_group_id() * 64;
    __local const float4* hi = lo + 32;

    const float4 a0 = regA.s0123;
    const float4 a1 = regA.s4567;
    const float4 a2 = regA.s89ab;
    const float4 a3 = regA.scdef;

    float4 acc = (float4)(0.0f);

    acc += a0.x * lo[0];
    acc += a0.y * lo[16];
    acc += a1.x * lo[1];
    acc += a1.y * lo[17];
    acc += a2.x * lo[2];
    acc += a2.y * lo[18];
    acc += a3.x * lo[3];
    acc += a3.y * lo[19];

    acc += a0.z * hi[0];
    acc += a0.w * hi[16];
    acc += a1.z * hi[1];
    acc += a1.w * hi[17];
    acc += a2.z * hi[2];
    acc += a2.w * hi[18];
    acc += a3.z * hi[3];
    acc += a3.w * hi[19];

    return acc;
}
61
/*
 * Produce 16 accumulator lanes from the sub-group's B slab in local memory.
 *
 * alu_32 is evaluated at float4 offsets 0, 4, 8 and 12, and the four
 * float4 results are concatenated in that order. Together the four calls
 * read every float4 of the 64-float4 slab. alu_32 only reads local memory,
 * so the calls are independent of each other.
 */
inline float16 alu_16(
    float16 regA,
    __local float* matrix_B_local
) {
    __local float4* slab = (__local float4*)matrix_B_local;

    const float4 part0 = alu_32(regA, slab);
    const float4 part1 = alu_32(regA, slab + 4);
    const float4 part2 = alu_32(regA, slab + 8);
    const float4 part3 = alu_32(regA, slab + 12);

    return (float16)(part0, part1, part2, part3);
}
76
/*
 * Perform one TILESIZE_K step of the register-tiled product for this
 * work-item. The step runs in two rounds:
 *   - Round 1 stores regB.s0..s3 and accumulates into *regC0_ptr.
 *   - Round 2 stores regB.s4..s7 and accumulates into *regC1_ptr.
 *
 * In each round, every work-item scatters four B values into its
 * sub-group's 256-float slab of matrix_B_local. The store positions are
 * b_localOffsetInWords + get_sub_group_id() * 256 + {0, 64, 128, 192}.
 * Every work-item then reads the whole slab through alu_16 and adds the
 * result into the round's accumulator.
 *
 * Because each work-item reads values stored by other work-items of the
 * same sub-group, sub-group barriers are required:
 *   - before each round of stores, so a store never overwrites a value a
 *     peer is still reading (from the previous round or the previous
 *     call);
 *   - after each round of stores, so all stores are visible before the
 *     slab is read.
 * All work-items of the sub-group must call this function (uniform
 * control flow).
 *
 * NOTE(review): the slab is only complete if all work-items that share a
 * B tile belong to one sub-group (sub-group size >= work-items in dim 0).
 * Confirm against the launch configuration.
 */
inline void mm_mad(
    __local float* matrix_B_local,
    float16 regA,
    float8 regB,
    uint b_localOffsetInWords,
    float16* regC0_ptr,
    float16* regC1_ptr
) {
    int offset = b_localOffsetInWords + get_sub_group_id() * 256;

    /* WAR: peers may still be reading the slab from the previous call. */
    sub_group_barrier(CLK_LOCAL_MEM_FENCE);
    matrix_B_local[offset + LM_FIRST_256B]  = regB.s0;
    matrix_B_local[offset + LM_SECOND_256B] = regB.s1;
    matrix_B_local[offset + LM_THIRD_256B]  = regB.s2;
    matrix_B_local[offset + LM_FOURTH_256B] = regB.s3;
    /* RAW: make every peer's stores visible before the full slab is read. */
    sub_group_barrier(CLK_LOCAL_MEM_FENCE);

    float16 add0 = alu_16(regA, matrix_B_local);
    *regC0_ptr += add0;

    /* WAR: every peer must finish reading round 1 before it is overwritten. */
    sub_group_barrier(CLK_LOCAL_MEM_FENCE);
    matrix_B_local[offset + LM_FIRST_256B]  = regB.s4;
    matrix_B_local[offset + LM_SECOND_256B] = regB.s5;
    matrix_B_local[offset + LM_THIRD_256B]  = regB.s6;
    matrix_B_local[offset + LM_FOURTH_256B] = regB.s7;
    /* RAW: make round-2 stores visible before reading. */
    sub_group_barrier(CLK_LOCAL_MEM_FENCE);

    float16 add1 = alu_16(regA, matrix_B_local);
    *regC1_ptr += add1;
}
103
/*
 * Write up to 16 outputs of one C column.
 *
 * Lane k of vals is output row (first + k). It is written to coordinate
 * coord0 + (first + k) * stride, and only when (first + k) < mask.
 */
inline void mm_store_c_16(
    __write_only image1d_buffer_t matrix_C,
    float16 vals,
    uint coord0,
    uint stride,
    int first,
    int mask
) {
    if (mask > first +  0) { write_imagef(matrix_C, coord0 + (first +  0) * stride, vals.s0); }
    if (mask > first +  1) { write_imagef(matrix_C, coord0 + (first +  1) * stride, vals.s1); }
    if (mask > first +  2) { write_imagef(matrix_C, coord0 + (first +  2) * stride, vals.s2); }
    if (mask > first +  3) { write_imagef(matrix_C, coord0 + (first +  3) * stride, vals.s3); }
    if (mask > first +  4) { write_imagef(matrix_C, coord0 + (first +  4) * stride, vals.s4); }
    if (mask > first +  5) { write_imagef(matrix_C, coord0 + (first +  5) * stride, vals.s5); }
    if (mask > first +  6) { write_imagef(matrix_C, coord0 + (first +  6) * stride, vals.s6); }
    if (mask > first +  7) { write_imagef(matrix_C, coord0 + (first +  7) * stride, vals.s7); }
    if (mask > first +  8) { write_imagef(matrix_C, coord0 + (first +  8) * stride, vals.s8); }
    if (mask > first +  9) { write_imagef(matrix_C, coord0 + (first +  9) * stride, vals.s9); }
    if (mask > first + 10) { write_imagef(matrix_C, coord0 + (first + 10) * stride, vals.sa); }
    if (mask > first + 11) { write_imagef(matrix_C, coord0 + (first + 11) * stride, vals.sb); }
    if (mask > first + 12) { write_imagef(matrix_C, coord0 + (first + 12) * stride, vals.sc); }
    if (mask > first + 13) { write_imagef(matrix_C, coord0 + (first + 13) * stride, vals.sd); }
    if (mask > first + 14) { write_imagef(matrix_C, coord0 + (first + 14) * stride, vals.se); }
    if (mask > first + 15) { write_imagef(matrix_C, coord0 + (first + 15) * stride, vals.sf); }
}

/*
 * Write this work-item's 32 accumulated outputs to C.
 *
 * Output r (0..31) is written to coordinate
 *   subMatrixCStartInElements + get_local_id(0)
 *     + r * (line_stride_matrix_C_in_bytes / 4).
 * The value is regC0 lane r for r < 16 and regC1 lane r - 16 otherwise.
 * Only outputs with r < mask are written, which trims a partial last tile.
 *
 * NOTE(review): presumably matrix_C uses a single-channel float format, so
 * one coordinate is one float. Confirm.
 * NOTE(review): the column index (start + get_local_id(0)) is not bounds
 * checked. Assumes the host guarantees it is in range. Confirm.
 */
inline void mm_store_c_N(
    __write_only image1d_buffer_t matrix_C,
    float16 regC0,
    float16 regC1,
    uint subMatrixCStartInElements,
    int line_stride_matrix_C_in_bytes,
    int mask
) {
    const uint row_stride = line_stride_matrix_C_in_bytes / 4;
    const uint origin = subMatrixCStartInElements + get_local_id(0);

    mm_store_c_16(matrix_C, regC0, origin, row_stride,  0, mask);
    mm_store_c_16(matrix_C, regC1, origin, row_stride, 16, mask);
}
182
#define TILESIZE_K 16
#define TILESIZE_M 64
#define TILESIZE_N 32

/*
 * Tiled matrix multiply with f16 A (read through an image), f32 B (global
 * buffer) and f32 C (written through an image).
 *
 * Variants:
 *   - Built as mul_mm_f16_f32_kqv when KQV is defined, and as
 *     mul_mm_f16_f32_kq otherwise.
 *   - The two variants differ only in how A and B are strided.
 *
 * Output:
 *   - C[d][n][m] is stored at float index d * N * M + n * M + m.
 *   - Its value is the sum over K of (row m of A batch depth_A) times
 *     (row n of B batch d), with depth_A = d / (D_B / D_A). Consecutive
 *     groups of D_B / D_A B batches therefore share one A batch.
 *
 * Work decomposition:
 *   - get_global_id(1) gives block_id_m, a TILESIZE_M-wide block of m.
 *   - get_global_id(2) gives block_id_n (a TILESIZE_N-wide block of n) and
 *     the batch d.
 *   - get_local_id(0) gives the m inside the block.
 *   - Each work-item accumulates TILESIZE_N outputs:
 *     n = row .. row + 15 in regC0 and row + 16 .. row + 31 in regC1.
 *
 * NOTE(review): offset0, offset1 and offsetd are not used by this kernel.
 * Presumably the host binds already-offset buffers. Confirm.
 * NOTE(review): the K loop has no tail handling, so K is assumed to be a
 * multiple of TILESIZE_K. D_A is assumed non-zero, with D_B a multiple of
 * D_A. Confirm against the host code.
 */
#ifdef KQV
__kernel void mul_mm_f16_f32_kqv(
#else
__kernel void mul_mm_f16_f32_kq(
#endif
    __read_only  image1d_buffer_t matrix_A,
    int offset0,
    __global float* matrix_B,
    int offset1,
    __write_only image1d_buffer_t matrix_C,
    int offsetd,
    int M, int K, int N,
    int D_A,
    int D_B,
    int nb01
) {
    uint block_id_m = get_global_id(1);
    uint block_id_n = get_global_id(2) % ((N+TILESIZE_N-1)/TILESIZE_N);
    uint block_id_d = get_global_id(2) / ((N+TILESIZE_N-1)/TILESIZE_N);

    /* The accumulators must start at zero: mm_mad only ever adds into them. */
    __private float16 regC0 = (float16)(0.0f);
    __private float16 regC1 = (float16)(0.0f);

    const uint col = block_id_m * TILESIZE_M;
    const uint row = block_id_n * TILESIZE_N;
    const uint depth_A = block_id_d / (D_B/D_A);
    const uint depth_B = block_id_d;

#ifdef KQV
    int line_stride_matrix_A_in_bytes = nb01 * M;
    int line_stride_matrix_B_in_bytes = K * N * 4;
#else
    int line_stride_matrix_A_in_bytes = K * D_A * 2;
    int line_stride_matrix_B_in_bytes = K * D_B * 4;
#endif

    int line_stride_matrix_C_in_bytes = M * 4;

    const uint strideAinElements = line_stride_matrix_A_in_bytes / 2;
    const uint strideBinElements = line_stride_matrix_B_in_bytes / 4;

    size_t sub_block_id_m = get_local_id(0);

    /* Where this work-item's B values go in the sub-group's local slab:
     * (id / 16) * 16, plus the low nibble of id with its two 2-bit halves
     * swapped (bits 0,1 -> 2,3 and bits 2,3 -> 0,1). mm_mad adds the
     * per-sub-group base. */
    uint b_localOffsetInWords = (sub_block_id_m/16)*16
        + ((((sub_block_id_m)>>0)&1)<<2)
        + ((((sub_block_id_m)>>1)&1)<<3)
        + ((((sub_block_id_m)>>2)&1)<<0)
        + ((((sub_block_id_m)>>3)&1)<<1);

    /* Each K step, this work-item loads 4 consecutive floats from B row y
     * (x = K offset) and the same 4 floats from row y + 16. */
    uint2 b_globalOffsetInWords_xy = {((sub_block_id_m%4)*4), (sub_block_id_m>>2)};
    uint b_globalOffsetInWords00, b_globalOffsetInWords16;
#ifdef KQV
    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*K;
    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * K);
    uint subMatrixAStartInElements = depth_A * strideAinElements + col * nb01 / 2;
    uint subMatrixBStartInElements = depth_B * strideBinElements + row * K;
#else
    b_globalOffsetInWords00 = b_globalOffsetInWords_xy.x + b_globalOffsetInWords_xy.y*line_stride_matrix_B_in_bytes/4;
    b_globalOffsetInWords16 = b_globalOffsetInWords00 + (16 * line_stride_matrix_B_in_bytes/4);
    uint subMatrixAStartInElements = col * strideAinElements + depth_A * K;
    uint subMatrixBStartInElements = row * strideBinElements + depth_B * K;
#endif

    __local float matrix_B_local[1024];

    for (uint step = 0; step < K; step += TILESIZE_K) {
        float16 regA = mm_load_a(matrix_A, subMatrixAStartInElements, nb01, line_stride_matrix_A_in_bytes);

        uint b_coordInWords00 = subMatrixBStartInElements + b_globalOffsetInWords00;
        uint b_coordInWords16 = subMatrixBStartInElements + b_globalOffsetInWords16;

        float8 regB;
        regB.s0123 = vload4(b_coordInWords00/4, matrix_B);
        regB.s4567 = vload4(b_coordInWords16/4, matrix_B);

        mm_mad(matrix_B_local, regA, regB, b_localOffsetInWords, &regC0, &regC1);

        subMatrixAStartInElements += TILESIZE_K;
        subMatrixBStartInElements += TILESIZE_K;
    }

    /* The final N tile may be partial, so only the first
     * N - block_id_n * TILESIZE_N of the 32 rows are stored. */
    uint subMatrixCStartInElements = depth_B * N * M + row * M + col;
    mm_store_c_N(matrix_C, regC0, regC1, subMatrixCStartInElements, line_stride_matrix_C_in_bytes, (N - block_id_n * TILESIZE_N));
}
273