#include "binbcast.cuh"
#include <cstdint>
#include <utility>

static __device__ __forceinline__ float op_repeat(const float a, const float b) {
    return b;
    GGML_UNUSED(a);
}

static __device__ __forceinline__ float op_add(const float a, const float b) {
    return a + b;
}

static __device__ __forceinline__ float op_sub(const float a, const float b) {
    return a - b;
}

static __device__ __forceinline__ float op_mul(const float a, const float b) {
    return a * b;
}

static __device__ __forceinline__ float op_div(const float a, const float b) {
    return a / b;
}

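// Generic broadcasting kernel: applies bin_op element-wise between src0 and one or more
// src1 operands, with the src1 indices wrapped modulo its extents (ne10..ne13) so that a
// smaller src1 is broadcast over dst. The uint3 parameters carry precomputed
// fastdiv/fastmodulo constants built by init_fastdiv_values. src0 may be a null pointer
// (the repeat path passes nullptr), in which case the accumulator starts at 0.0f. The
// variadic src1s pack allows several fused operands to be folded into a single launch.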
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
                                   const src1_t * src1,
                                   dst_t * dst,
                                   const int ne0,
                                   const int ne1,
                                   const int ne2,
                                   const uint3 ne3,
                                   const uint3 ne10,
                                   const uint3 ne11,
                                   const uint3 ne12,
                                   const uint3 ne13,
                                   /*const int s0,*/
                                   const int s1,
                                   const int s2,
                                   const int s3,
                                   const int s00,
                                   const int s01,
                                   const int s02,
                                   const int s03,
                                   const int s10,
                                   const int s11,
                                   const int s12,
                                   const int s13,
                                   src1_ptrs... src1s) {
    const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
    const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
    const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);

    if (i0s >= (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) {
        return;
    }

    const uint32_t i11 = fastmodulo(i1, ne11);
    const uint32_t i12 = fastmodulo(i2, ne12);
    const uint32_t i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
        const uint32_t i10 = fastmodulo(i0, ne10);

        float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
        if constexpr (sizeof...(src1_ptrs) > 0) {
            result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
        } else {
            result = bin_op(result, (float)src1[i_src1 + i10*s10]);
        }

        dst_row[i0] = (dst_t) result;
    }
}

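// Fallback variant for launches where the regular grid would exceed CUDA's 65535 limit
// on the y/z grid dimensions: the whole problem is launched as a flat 1D grid and each
// thread unravels its linear index back into (i0, i1, i2, i3) with fastdiv, using the
// precomputed products ne0*ne1*ne2 (prod_012) and ne0*ne1 (prod_01).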
template <float (*bin_op)(const float, const float),
          typename src0_t,
          typename src1_t,
          typename dst_t,
          typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0,
                                           const src1_t * src1,
                                           dst_t * dst,
                                           const uint3 ne0,
                                           const uint3 ne1,
                                           const uint3 ne2,
                                           const uint32_t ne3,
                                           const uint3 prod_012,
                                           const uint3 prod_01,
                                           const uint3 ne10,
                                           const uint3 ne11,
                                           const uint3 ne12,
                                           const uint3 ne13,
                                           /*const int s0,*/
                                           const int s1,
                                           const int s2,
                                           const int s3,
                                           const int s00,
                                           const int s01,
                                           const int s02,
                                           const int s03,
                                           const int s10,
                                           const int s11,
                                           const int s12,
                                           const int s13,
                                           src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    const uint32_t i3 = fastdiv(i, prod_012);
    const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
    const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
    const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;

    if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
        return;
    }

    const int i11 = fastmodulo(i1, ne11);
    const int i12 = fastmodulo(i2, ne12);
    const int i13 = fastmodulo(i3, ne13);

    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
    const size_t i_dst  = i3*s3 + i2*s2 + i1*s1;

    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
    dst_t * dst_row = dst + i_dst;

    const int i10 = fastmodulo(i0, ne10);

    float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
    if constexpr (sizeof...(src1_ptrs) > 0) {
        result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
    } else {
        result = bin_op(result, (float)src1[i_src1 + i10*s10]);
    }

    dst_row[i0] = (dst_t) result;
}

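// Host-side launcher shared by the plain and fused paths. It collapses dimensions that
// are not broadcast (when both operands are contiguous and unpermuted) to reduce
// per-element index math, converts the tensors' byte strides into element strides, and
// picks between k_bin_bcast and k_bin_bcast_unravel depending on whether the 3D grid
// fits CUDA's y/z limits. The index_sequence I... enumerates the fused src1 operands,
// whose device pointers are taken from dst->src[I + 1].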
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                                  cudaStream_t stream, std::index_sequence<I...>) {
    GGML_TENSOR_BINARY_OP_LOCALS

    int nr0 = ne10 / ne0;
    int nr1 = ne11 / ne1;
    int nr2 = ne12 / ne2;
    int nr3 = ne13 / ne3;

    int nr[4] = { nr0, nr1, nr2, nr3 };

    int64_t cne[]  = { ne0,  ne1,  ne2,  ne3 };
    int64_t cne0[] = { ne00, ne01, ne02, ne03 };
    int64_t cne1[] = { ne10, ne11, ne12, ne13 };

    size_t cnb[]  = { nb0,  nb1,  nb2,  nb3 };
    size_t cnb0[] = { nb00, nb01, nb02, nb03 };
    size_t cnb1[] = { nb10, nb11, nb12, nb13 };

    auto collapse = [](int64_t cne[]) {
        cne[0] *= cne[1];
        cne[1] = cne[2];
        cne[2] = cne[3];
        cne[3] = 1;
    };

    auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
        cnb[1] *= cne[1];
        cnb[2] *= cne[2];
        cnb[3] *= cne[3];
    };

    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
        for (int i = 0; i < 4; i++) {
            if (nr[i] != 1) {
                break;
            }
            if (i > 0) {
                collapse_nb(cnb, cne);
                collapse_nb(cnb0, cne0);
                collapse_nb(cnb1, cne1);
                collapse(cne);
                collapse(cne0);
                collapse(cne1);
            }
        }
    }

    {
        int64_t ne0 = cne[0];
        int64_t ne1 = cne[1];
        int64_t ne2 = cne[2];
        int64_t ne3 = cne[3];

        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

        size_t nb0 = cnb[0];
        size_t nb1 = cnb[1];
        size_t nb2 = cnb[2];
        size_t nb3 = cnb[3];

        size_t nb00 = cnb0[0];
        size_t nb01 = cnb0[1];
        size_t nb02 = cnb0[2];
        size_t nb03 = cnb0[3];

        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
        size_t nb12 = cnb1[2];
        size_t nb13 = cnb1[3];

        //size_t s0 = nb0 / sizeof(dst_t);
        size_t s1 = nb1 / sizeof(dst_t);
        size_t s2 = nb2 / sizeof(dst_t);
        size_t s3 = nb3 / sizeof(dst_t);

        size_t s10 = nb10 / sizeof(src1_t);
        size_t s11 = nb11 / sizeof(src1_t);
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

        size_t s00 = nb00 / sizeof(src0_t);
        size_t s01 = nb01 / sizeof(src0_t);
        size_t s02 = nb02 / sizeof(src0_t);
        size_t s03 = nb03 / sizeof(src0_t);

        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);

        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);

        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);

        const int block_size = 128;

        int64_t hne0 = std::max(ne0 / 2LL, 1LL);

        dim3 block_dims;
        block_dims.x = std::min<unsigned int>(hne0, block_size);
        block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
        block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);

        dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
                        (ne2 * ne3 + block_dims.z - 1) / block_dims.z);

        const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
        const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
        const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
        const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

        if (block_nums.z > 65535 || block_nums.y > 65535) {
            int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
            const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
            const uint3 prod_01  = init_fastdiv_values((uint32_t) (ne0 * ne1));
            const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
            const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
            const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);

            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
                    ne12, ne13,
                    /*s0,*/ s1, s2, s3,
                    s00, s01, s02, s03,
                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
                                                           ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
                                                           /*s0,*/ s1, s2, s3,
                                                           s00, s01, s02, s03,
                                                           s10, s11, s12, s13);
            }
        } else {
            const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
            if constexpr (sizeof...(I) > 0) {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /*s0,*/ s1, s2, s3,
                    s00, s01, s02, s03,
                    s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
            } else {
                k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
                    /*s0,*/ s1, s2, s3,
                    s00, s01, s02, s03,
                    s10, s11, s12, s13);
            }
        }
    }
}

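// Adjoint of ggml_repeat: every element of the (smaller) dst accumulates all elements of
// the (larger) src that were produced from it by the repeat, grid-striding over the
// repeated copies along each of the four dimensions. dst is written contiguously.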
template <typename T>
static __global__ void k_repeat_back(
    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {

    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
    const int64_t tid2  = tid23 % ne2;
    const int64_t tid3  = tid23 / ne2;

    if (tid0 >= ne0) {
        return;
    }

    T sum = 0;
    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
                }
            }
        }
    }
    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}

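// Functor that binds a device operator and a fusion count so it can be passed as a
// template argument to the type dispatcher below; it forwards to launch_bin_bcast_pack
// with an index_sequence of length n_fuse.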
template <float (*bin_op)(const float, const float), int n_fuse = 1>
struct bin_bcast_cuda {
    template<typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
                    const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                    cudaStream_t stream) {
        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
    }
};

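// Launch helper for k_repeat_back: one warp-wide block per chunk of ne0, with the grid's
// y and z dimensions covering ne1 and ne2*ne3 respectively.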
template <typename T>
static void repeat_back_cuda(
    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {

    const dim3 block_dims(WARP_SIZE, 1, 1);
    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
}

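// Dispatches a broadcast binary op to the correct template instantiation based on the
// src0/src1/dst type combination (F32 and F16 mixes); unsupported combinations abort.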
template<class op>
static void ggml_cuda_op_bin_bcast(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
    const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {

    GGML_ASSERT(src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const float *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const half *) src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (half *) dst_dd, stream);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        op()(src0, src1, dst, (const half *) src0_dd, (const float *) src1_dd, (float *) dst_dd, stream);
    } else {
        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
                ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

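// Public entry points for the individual graph ops. Note that repeat is expressed as a
// broadcast "copy": dst stands in for src0 with a null data pointer and dst->src[0] is
// passed as src1, so op_repeat simply writes the broadcast src1 value.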
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_sub>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

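// Fused path: instantiates the launcher with a compile-time operand count so that n_fuse
// binary ops writing into the same dst are folded into a single kernel launch; the
// additional operands are read from dst->src[1..n_fuse].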
template <float (*op)(const float, const float), int n_fuse>
static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
        fprintf(stderr,
                "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
        GGML_ABORT("fatal error");
    }
}

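// Maps the runtime fusion count onto the fixed set of template instantiations; callers
// must request between 2 and 8 fused adds (enforced by the assertion), presumably coming
// from the backend's op-fusion logic.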
void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2:
            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
            break;
        case 3:
            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
            break;
        case 4:
            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
            break;
        case 5:
            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
            break;
        case 6:
            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
            break;
        case 7:
            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
            break;
        case 8:
            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
    }
}

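// GGML_OP_REPEAT_BACK: reduces the gradient of a repeated tensor back onto the shape of
// the original tensor. Only F32 is handled; ne2*ne3 is bounded so it stays well within
// the grid's z dimension limit.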
void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    GGML_ASSERT(src0->type == dst->type);
    GGML_ASSERT(ggml_is_contiguous(dst));
    GGML_ASSERT(ggml_can_repeat(dst, src0));

    cudaStream_t stream = ctx.stream();

    GGML_TENSOR_UNARY_OP_LOCALS;

    GGML_ASSERT(ne2*ne3 <= (1 << 15));

    const size_t ts = ggml_type_size(src0->type);
    const size_t s00 = nb00 / ts;
    const size_t s01 = nb01 / ts;
    const size_t s02 = nb02 / ts;
    const size_t s03 = nb03 / ts;

    switch (dst->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            float * dst_d = (float *) dst->data;
            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
        } break;
        default: {
            GGML_ASSERT(false);
        } break;
    }
}