1#include "ops.h"
2
3#include "ggml-cpu.h"
4#include "ggml-impl.h"
5#include "binary-ops.h"
6#include "ggml.h"
7#include "unary-ops.h"
8#include "vec.h"
9
10#include <algorithm>
11#include <cfloat>
12#include <cmath>
13
14// ggml_compute_forward_dup
15
16static void ggml_compute_forward_dup_same_cont(
17 const ggml_compute_params * params,
18 ggml_tensor * dst) {
19
20 const ggml_tensor * src0 = dst->src[0];
21
22 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
23 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
24 GGML_ASSERT(src0->type == dst->type);
25
26 const size_t nb0 = ggml_type_size(src0->type);
27
28 const int ith = params->ith; // thread index
29 const int nth = params->nth; // number of threads
30
31 // parallelize by blocks
32 const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
33 const int dr = (nk + nth - 1) / nth;
34 const int k0 = dr * ith;
35 const int k1 = MIN(k0 + dr, nk);
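    // e.g. with nk = 10 blocks and nth = 4 threads: dr = 3, so thread 0 copies blocks [0, 3)
    // and thread 3 copies [9, 10); when nth > nk some threads get an empty range, which the
    // k0 < k1 guard below turns into a no-op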
36
37 if (k0 < k1) {
38 memcpy(
39 ((char *) dst->data + k0*nb0),
40 ((char *) src0->data + k0*nb0),
41 (k1 - k0) * nb0);
42 }
43}
44
45template<typename src_t, typename dst_t>
46static void ggml_compute_forward_dup_flt(
47 const ggml_compute_params * params,
48 ggml_tensor * dst) {
49
50 const ggml_tensor * src0 = dst->src[0];
51
52 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
53 GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
54
55 GGML_TENSOR_UNARY_OP_LOCALS
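    // the macro above declares the usual locals (see ggml-impl.h):
    //   ne00..ne03 / nb00..nb03 - dimensions and byte strides of src0
    //   ne0 ..ne3  / nb0 ..nb3  - dimensions and byte strides of dst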
56
57 const int ith = params->ith; // thread index
58 const int nth = params->nth; // number of threads
59
60 // parallelize by rows
61 const int nr = ne01;
62 // number of rows per thread
63 const int dr = (nr + nth - 1) / nth;
64 // row range for this thread
65 const int ir0 = dr * ith;
66 const int ir1 = MIN(ir0 + dr, nr);
67
68 // case: type & row size equal
69 if (src0->type == dst->type &&
70 ne00 == ne0 &&
71 nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
72 // copy by rows
73 const size_t rs = ne00*nb00;
74 for (int64_t i03 = 0; i03 < ne03; i03++) {
75 for (int64_t i02 = 0; i02 < ne02; i02++) {
76 for (int64_t i01 = ir0; i01 < ir1; i01++) {
77 memcpy(
78 ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
79 ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
80 rs);
81 }
82 }
83 }
84 return;
85 }
86
87 // case: dst tensor is contiguous
88 if (ggml_is_contiguous(dst)) {
89 if (nb00 == sizeof(src_t)) {
90 if constexpr (std::is_same_v<dst_t, src_t>) {
91 // same type
92 size_t id = 0;
93 const size_t rs = ne00 * nb00;
94 char * dst_ptr = (char *) dst->data;
95
96 for (int i03 = 0; i03 < ne03; i03++) {
97 for (int i02 = 0; i02 < ne02; i02++) {
98 id += rs * ir0;
99 for (int i01 = ir0; i01 < ir1; i01++) {
100 const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
101 memcpy(dst_ptr + id, src0_ptr, rs);
102 id += rs;
103 }
104 id += rs * (ne01 - ir1);
105 }
106 }
107 } else {
108 // casting between non-quantized types
109 size_t id = 0;
110 dst_t * dst_ptr = (dst_t *) dst->data;
111
112 for (int i03 = 0; i03 < ne03; i03++) {
113 for (int i02 = 0; i02 < ne02; i02++) {
114 id += ne00 * ir0;
115 for (int i01 = ir0; i01 < ir1; i01++) {
116 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
117 for (int i00 = 0; i00 < ne00; i00++) {
118 float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
119 dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
120 id++;
121 }
122 }
123 id += ne00 * (ne01 - ir1);
124 }
125 }
126 }
127 } else {
128 //printf("%s: this is not optimal - fix me\n", __func__);
129
130 size_t id = 0;
131 dst_t * dst_ptr = (dst_t *) dst->data;
132
133 for (int i03 = 0; i03 < ne03; i03++) {
134 for (int i02 = 0; i02 < ne02; i02++) {
135 id += ne00 * ir0;
136 for (int i01 = ir0; i01 < ir1; i01++) {
137 for (int i00 = 0; i00 < ne00; i00++) {
138 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
139
140 float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
141 dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
142 id++;
143 }
144 }
145 id += ne00 * (ne01 - ir1);
146 }
147 }
148 }
149 return;
150 }
151
152 // dst counters
153 int64_t i10 = 0;
154 int64_t i11 = 0;
155 int64_t i12 = 0;
156 int64_t i13 = 0;
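    // i10..i13 track the destination element coordinates for the running flat index:
    // every store advances i10 and carries into i11/i12/i13 when a dst dimension fills up,
    // and the "i10 += ne00 * ir0" adjustments below skip over the elements that belong to
    // rows handled by other threads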
157
158 if constexpr (std::is_same_v<dst_t, src_t>) {
159 for (int64_t i03 = 0; i03 < ne03; i03++) {
160 for (int64_t i02 = 0; i02 < ne02; i02++) {
161 i10 += ne00 * ir0;
162 while (i10 >= ne0) {
163 i10 -= ne0;
164 if (++i11 == ne1) {
165 i11 = 0;
166 if (++i12 == ne2) {
167 i12 = 0;
168 if (++i13 == ne3) {
169 i13 = 0;
170 }
171 }
172 }
173 }
174 for (int64_t i01 = ir0; i01 < ir1; i01++) {
175 for (int64_t i00 = 0; i00 < ne00; i00++) {
176 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
177 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
178
179 memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
180
                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
188 i13 = 0;
189 }
190 }
191 }
192 }
193 }
194 }
195 i10 += ne00 * (ne01 - ir1);
196 while (i10 >= ne0) {
197 i10 -= ne0;
198 if (++i11 == ne1) {
199 i11 = 0;
200 if (++i12 == ne2) {
201 i12 = 0;
202 if (++i13 == ne3) {
203 i13 = 0;
204 }
205 }
206 }
207 }
208 }
209 }
210
211 } else {
212 for (int64_t i03 = 0; i03 < ne03; i03++) {
213 for (int64_t i02 = 0; i02 < ne02; i02++) {
214 i10 += ne00 * ir0;
215 while (i10 >= ne0) {
216 i10 -= ne0;
217 if (++i11 == ne1) {
218 i11 = 0;
219 if (++i12 == ne2) {
220 i12 = 0;
221 if (++i13 == ne3) {
222 i13 = 0;
223 }
224 }
225 }
226 }
227 for (int64_t i01 = ir0; i01 < ir1; i01++) {
228 for (int64_t i00 = 0; i00 < ne00; i00++) {
229 const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
230 char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
231
232 float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
233 *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
234
235 if (++i10 == ne0) {
236 i10 = 0;
237 if (++i11 == ne1) {
238 i11 = 0;
239 if (++i12 == ne2) {
240 i12 = 0;
241 if (++i13 == ne3) {
242 i13 = 0;
243 }
244 }
245 }
246 }
247 }
248 }
249 i10 += ne00 * (ne01 - ir1);
250 while (i10 >= ne0) {
251 i10 -= ne0;
252 if (++i11 == ne1) {
253 i11 = 0;
254 if (++i12 == ne2) {
255 i12 = 0;
256 if (++i13 == ne3) {
257 i13 = 0;
258 }
259 }
260 }
261 }
262 }
263 }
264 }
265}
266
267
268template<typename src_t>
269static void ggml_compute_forward_dup_to_q(
270 const ggml_compute_params * params,
271 ggml_tensor * dst) {
272
273 const ggml_tensor * src0 = dst->src[0];
274
275 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
276 GGML_ASSERT(!ggml_is_quantized(src0->type));
277
278 GGML_TENSOR_UNARY_OP_LOCALS
279
280 const int ith = params->ith; // thread index
281 const int nth = params->nth; // number of threads
282
283 // parallelize by rows
284 const int nr = ne01;
285 // number of rows per thread
286 const int dr = (nr + nth - 1) / nth;
287 // row range for this thread
288 const int ir0 = dr * ith;
289 const int ir1 = MIN(ir0 + dr, nr);
290
291 if (ggml_is_contiguous(dst) &&
292 nb00 == sizeof(src_t) &&
293 ggml_get_type_traits_cpu(dst->type)->from_float) {
294 // casting non-quantized types --> intermediate f32 --> quantized
295 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
296 float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
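        // src0_f32 points into params->wdata: one scratch row of ne00 floats per thread; the
        // extra CACHE_LINE_SIZE_F32 floats of padding keep the per-thread buffers on separate
        // cache lines so threads do not contend on the same lines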
297
298 size_t id = 0;
299 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
300 char * dst_ptr = (char *) dst->data;
301
302 for (int i03 = 0; i03 < ne03; i03++) {
303 for (int i02 = 0; i02 < ne02; i02++) {
304 id += rs * ir0;
305 for (int i01 = ir0; i01 < ir1; i01++) {
306 const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
307
308 for (int i00 = 0; i00 < ne00; i00++) {
309 src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
310 }
311
312 quantize_row_q(src0_f32, dst_ptr + id, ne00);
313 id += rs;
314 }
315 id += rs * (ne01 - ir1);
316 }
317 }
318 } else {
319 // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
320 GGML_ABORT("not implemented");
321 }
322}
323
// A simplified version of ggml_compute_forward_dup that does no float upcasting and instead copies raw bytes with plain memcpy.
325static void ggml_compute_forward_dup_bytes(
326 const ggml_compute_params * params,
327 ggml_tensor * dst) {
328 const ggml_tensor * src0 = dst->src[0];
329
330 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
331 GGML_ASSERT(src0->type == dst->type);
332
333 GGML_TENSOR_UNARY_OP_LOCALS;
334
335 if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
336 ggml_compute_forward_dup_same_cont(params, dst);
337 return;
338 }
339
340 const size_t type_size = ggml_type_size(src0->type);
341
342 const int ith = params->ith; // thread index
343 const int nth = params->nth; // number of threads
344
345 // parallelize by rows
346 const int nr = ne01;
347 // number of rows per thread
348 const int dr = (nr + nth - 1) / nth;
349 // row range for this thread
350 const int ir0 = dr * ith;
351 const int ir1 = MIN(ir0 + dr, nr);
352
353 if (src0->type == dst->type &&
354 ggml_are_same_shape(src0, dst) &&
355 nb00 == type_size && nb0 == type_size) {
356 // copy by rows
357 const size_t rs = ggml_row_size(src0->type, ne00);
358 for (int64_t i03 = 0; i03 < ne03; i03++) {
359 for (int64_t i02 = 0; i02 < ne02; i02++) {
360 for (int64_t i01 = ir0; i01 < ir1; i01++) {
361 memcpy(
362 ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
363 ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
364 rs);
365 }
366 }
367 }
368 return;
369 }
370
371 if (ggml_is_contiguous(dst)) {
372 size_t id = 0;
373 char * dst_ptr = (char *) dst->data;
374 const size_t rs = ne00 * type_size;
375
376 if (nb00 == type_size) {
            // src0 is contiguous along the first dimension, copy by rows
378 for (int64_t i03 = 0; i03 < ne03; i03++) {
379 for (int64_t i02 = 0; i02 < ne02; i02++) {
380 id += rs * ir0;
381 for (int64_t i01 = ir0; i01 < ir1; i01++) {
382 const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
383 memcpy(dst_ptr + id, src0_ptr, rs);
384 id += rs;
385 }
386 id += rs * (ne01 - ir1);
387 }
388 }
389 } else {
390 //printf("%s: this is not optimal - fix me\n", __func__);
391
392 for (int64_t i03 = 0; i03 < ne03; i03++) {
393 for (int64_t i02 = 0; i02 < ne02; i02++) {
394 id += rs * ir0;
395 for (int64_t i01 = ir0; i01 < ir1; i01++) {
396 for (int64_t i00 = 0; i00 < ne00; i00++) {
397 const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
398 memcpy(dst_ptr + id, src0_ptr, type_size);
399
400 id += type_size;
401 }
402 }
403 id += rs * (ne01 - ir1);
404 }
405 }
406 }
407
408 return;
409 }
410
411 // dst counters
412 int64_t k10 = 0;
413 int64_t i11 = 0;
414 int64_t i12 = 0;
415 int64_t i13 = 0;
416
417 // number of blocks in a row
418 const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
419 const int64_t nk0 = ne0 / ggml_blck_size(dst->type);
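    // nk00/nk0 are the number of type-sized blocks per row of src0/dst; for non-quantized
    // types ggml_blck_size() is 1, so these are simply the element counts ne00/ne0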
420
421 for (int64_t i03 = 0; i03 < ne03; i03++) {
422 for (int64_t i02 = 0; i02 < ne02; i02++) {
423 k10 += nk00 * ir0;
424 while (k10 >= nk0) {
425 k10 -= nk0;
426 if (++i11 == ne1) {
427 i11 = 0;
428 if (++i12 == ne2) {
429 i12 = 0;
430 if (++i13 == ne3) {
431 i13 = 0;
432 }
433 }
434 }
435 }
436 for (int64_t i01 = ir0; i01 < ir1; i01++) {
437 for (int64_t k00 = 0; k00 < nk00; k00++) {
438 const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
439 char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
440
441 memcpy(dst_ptr, src0_ptr, type_size);
442
443 if (++k10 == nk0) {
444 k10 = 0;
445 if (++i11 == ne1) {
446 i11 = 0;
447 if (++i12 == ne2) {
448 i12 = 0;
449 if (++i13 == ne3) {
450 i13 = 0;
451 }
452 }
453 }
454 }
455 }
456 }
457 k10 += nk00 * (ne01 - ir1);
458 while (k10 >= nk0) {
459 k10 -= nk0;
460 if (++i11 == ne1) {
461 i11 = 0;
462 if (++i12 == ne2) {
463 i12 = 0;
464 if (++i13 == ne3) {
465 i13 = 0;
466 }
467 }
468 }
469 }
470 }
471 }
472}
473
474static void ggml_compute_forward_dup_from_q(
475 const ggml_compute_params * params,
476 ggml_tensor * dst) {
477
478 const ggml_tensor * src0 = dst->src[0];
479 const ggml_tensor * src1 = dst->src[1];
480
481 GGML_TENSOR_BINARY_OP_LOCALS
482
483 const ggml_type type = src0->type;
484 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
485
486 size_t qk = ggml_blck_size(type);
487 const int64_t nr = ggml_nelements(src1) / qk;
488
489 // destination must be contiguous in the first dimension
490 GGML_ASSERT(nb10 == ggml_type_size(dst->type));
491 // must either have first dimension large enough to hold a row, or fully contiguous
492 GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
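    // each ir below covers one block of qk elements: the flat element index i = ir*qk is
    // unflattened into src0 coordinates (i00..i03) to locate the quantized block, and into
    // src1/dst coordinates (i10..i13) to locate where the dequantized floats are written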
493
494 const int ith = params->ith;
495 const int nth = params->nth;
496
497 const int dr = (nr + nth - 1)/nth;
498
499 // row range for this thread
500 const int ir0 = dr*ith;
501 const int ir1 = MIN(ir0 + dr, nr);
502
503 for (int64_t ir = ir0; ir < ir1; ++ir) {
504
505 uint32_t i = ir * qk;
506
507 const int64_t i03 = i/(ne00 * ne01 * ne02);
508 const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
509 const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
510 const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
511 const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
512
513 const int64_t i13 = i/(ne10 * ne11 * ne12);
514 const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
515 const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
516 const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
517 const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
518
519 dequantize_row_q(
520 (const void *) ((char *) src0->data + x_offset),
521 (float *) ((char *) dst->data + dst_offset), qk);
522 }
523}
524
525void ggml_compute_forward_dup(
526 const ggml_compute_params * params,
527 ggml_tensor * dst) {
528
529 const ggml_tensor * src0 = dst->src[0];
530
531 if (src0->type == dst->type) {
532 ggml_compute_forward_dup_bytes(params, dst);
533 return;
534 }
535
536 switch (src0->type) {
537 case GGML_TYPE_F16:
538 {
539 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
540 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
541 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_fp16_t, float >(params, dst);
542 else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
543 } break;
544 case GGML_TYPE_BF16:
545 {
546 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
547 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
548 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<ggml_bf16_t, float >(params, dst);
549 else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
550 } break;
551 case GGML_TYPE_F32:
552 {
553 /**/ if (dst->type == GGML_TYPE_F16) ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
554 else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
555 else if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<float, float >(params, dst);
556 else if (dst->type == GGML_TYPE_I32) ggml_compute_forward_dup_flt<float, int32_t >(params, dst);
557 else ggml_compute_forward_dup_to_q<float>(params, dst);
558 } break;
559 case GGML_TYPE_I32:
560 {
561 if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
562 else GGML_ABORT("not implemented");
563 } break;
564 default:
565 {
566 if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
567 ggml_compute_forward_dup_from_q(params, dst);
568 break;
569 }
570 GGML_ABORT("fatal error");
571 }
572 }
573}
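// A minimal usage sketch of how this path is typically reached from the public ggml API
// (graph-construction code outside this file; ctx/n/m are placeholders, shown only for orientation):
//
//   ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n, m);
//   ggml_tensor * b = ggml_cast(ctx, a, GGML_TYPE_F16); // lowered to a dup/cpy of a into an F16 tensor
//
// In the CPU backend, GGML_OP_CPY and GGML_OP_CONT are forwarded to this same dup implementation.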
574
575// ggml_compute_forward_add
576
577static void ggml_compute_forward_add_q_f32(
578 const ggml_compute_params * params,
579 ggml_tensor * dst) {
580
581 const ggml_tensor * src0 = dst->src[0];
582 const ggml_tensor * src1 = dst->src[1];
583
584 GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
585
586 const int nr = ggml_nrows(src0);
587
588 GGML_TENSOR_BINARY_OP_LOCALS
589
590 const int ith = params->ith;
591 const int nth = params->nth;
592
593 const ggml_type type = src0->type;
594 const ggml_type dtype = dst->type;
595 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
596 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
597
598 // we don't support permuted src0 or src1
599 GGML_ASSERT(nb00 == ggml_type_size(type));
600 GGML_ASSERT(nb10 == sizeof(float));
601
602 // dst cannot be transposed or permuted
603 GGML_ASSERT(nb0 <= nb1);
604 GGML_ASSERT(nb1 <= nb2);
605 GGML_ASSERT(nb2 <= nb3);
606
607 GGML_ASSERT(ggml_is_quantized(src0->type));
608 GGML_ASSERT(src1->type == GGML_TYPE_F32);
609
610 // rows per thread
611 const int dr = (nr + nth - 1)/nth;
612
613 // row range for this thread
614 const int ir0 = dr*ith;
615 const int ir1 = MIN(ir0 + dr, nr);
616
617 float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
618
619 for (int ir = ir0; ir < ir1; ++ir) {
620 // src0 indices
621 const int i03 = ir/(ne02*ne01);
622 const int i02 = (ir - i03*ne02*ne01)/ne01;
623 const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
624
625 // src1 and dst are same shape as src0 => same indices
626 const int i13 = i03;
627 const int i12 = i02;
628 const int i11 = i01;
629
630 const int i3 = i03;
631 const int i2 = i02;
632 const int i1 = i01;
633
634 void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
635 float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
636 void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
637
638 assert(ne00 % 32 == 0);
639
        // dequantize row from src0 to temp buffer
641 dequantize_row_q(src0_row, wdata, ne00);
642 // add src1
643 ggml_vec_acc_f32(ne00, wdata, src1_row);
644 // quantize row to dst
645 if (quantize_row_q != NULL) {
646 quantize_row_q(wdata, dst_row, ne00);
647 } else {
648 memcpy(dst_row, wdata, ne0*nb0);
649 }
650 }
651}
652
653void ggml_compute_forward_add(
654 const ggml_compute_params * params,
655 ggml_tensor * dst) {
656
657 const ggml_tensor * src0 = dst->src[0];
658
659 switch (src0->type) {
660 case GGML_TYPE_F32:
661 case GGML_TYPE_F16:
662 case GGML_TYPE_BF16:
663 {
664 ggml_compute_forward_add_non_quantized(params, dst);
665 } break;
666 case GGML_TYPE_Q4_0:
667 case GGML_TYPE_Q4_1:
668 case GGML_TYPE_Q5_0:
669 case GGML_TYPE_Q5_1:
670 case GGML_TYPE_Q8_0:
671 case GGML_TYPE_MXFP4:
672 case GGML_TYPE_Q2_K:
673 case GGML_TYPE_Q3_K:
674 case GGML_TYPE_Q4_K:
675 case GGML_TYPE_Q5_K:
676 case GGML_TYPE_Q6_K:
677 case GGML_TYPE_TQ1_0:
678 case GGML_TYPE_TQ2_0:
679 case GGML_TYPE_IQ2_XXS:
680 case GGML_TYPE_IQ2_XS:
681 case GGML_TYPE_IQ3_XXS:
682 case GGML_TYPE_IQ1_S:
683 case GGML_TYPE_IQ1_M:
684 case GGML_TYPE_IQ4_NL:
685 case GGML_TYPE_IQ4_XS:
686 case GGML_TYPE_IQ3_S:
687 case GGML_TYPE_IQ2_S:
688 {
689 ggml_compute_forward_add_q_f32(params, dst);
690 } break;
691 default:
692 {
693 GGML_ABORT("fatal error");
694 }
695 }
696}
697
698// ggml_compute_forward_add_id
699
700static void ggml_compute_forward_add_id_f32(
701 const ggml_compute_params * params,
702 ggml_tensor * dst) {
703
704 const ggml_tensor * src0 = dst->src[0];
705 const ggml_tensor * src1 = dst->src[1];
706 const ggml_tensor * src2 = dst->src[2];
707
708 GGML_ASSERT(dst->type == GGML_TYPE_F32);
709 GGML_ASSERT(src0->type == GGML_TYPE_F32);
710 GGML_ASSERT(src1->type == GGML_TYPE_F32);
711 GGML_ASSERT(src2->type == GGML_TYPE_I32);
712
713 GGML_ASSERT(src0->nb[0] == sizeof(float));
714 GGML_ASSERT(src1->nb[0] == sizeof(float));
715
716 const int ith = params->ith;
717 const int nth = params->nth;
718
719 const int nr = ggml_nrows(src0);
720
721 GGML_TENSOR_TERNARY_OP_LOCALS
722
723 GGML_ASSERT( nb0 == sizeof(float));
724 GGML_ASSERT(nb10 == sizeof(float));
725
726 // rows per thread
727 const int dr = (nr + nth - 1)/nth;
728
729 // row range for this thread
730 const int ir0 = dr*ith;
731 const int ir1 = MIN(ir0 + dr, nr);
732
733 for (int ir = ir0; ir < ir1; ++ir) {
734 // src0 indices
735 const int i3 = ir/(ne2*ne1);
736 const int i2 = (ir - i3*ne2*ne1)/ne1;
737 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
738
739 // src1 indices
740 const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
741
742 GGML_ASSERT(i11 >= 0 && i11 < ne11);
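        // src2 holds an int32 index per (i1, i2); the selected row of src1 is added to the
        // corresponding row of src0, i.e. dst[...] = src0[...] + src1[i11]
        // (the per-row bias-selection pattern used e.g. for per-expert biases)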
743
744 ggml_vec_add_f32(ne0,
745 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
746 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
747 (float *) ((char *) src1->data + i11*nb11));
748 }
749}
750
751void ggml_compute_forward_add_id(
752 const ggml_compute_params * params,
753 ggml_tensor * dst) {
754
755 const ggml_tensor * src0 = dst->src[0];
756
757 switch (src0->type) {
758 case GGML_TYPE_F32:
759 {
760 ggml_compute_forward_add_id_f32(params, dst);
761 } break;
762 default:
763 {
764 GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
765 }
766 }
767}
768
769// ggml_compute_forward_add1
770
771static void ggml_compute_forward_add1_f32(
772 const ggml_compute_params * params,
773 ggml_tensor * dst) {
774
775 const ggml_tensor * src0 = dst->src[0];
776 const ggml_tensor * src1 = dst->src[1];
777
778 GGML_ASSERT(ggml_are_same_shape(src0, dst));
779 GGML_ASSERT(ggml_is_scalar(src1));
780
781 const int ith = params->ith;
782 const int nth = params->nth;
783
784 const int nr = ggml_nrows(src0);
785
786 GGML_TENSOR_UNARY_OP_LOCALS
787
788 GGML_ASSERT( nb0 == sizeof(float));
789 GGML_ASSERT(nb00 == sizeof(float));
790
791 // rows per thread
792 const int dr = (nr + nth - 1)/nth;
793
794 // row range for this thread
795 const int ir0 = dr*ith;
796 const int ir1 = MIN(ir0 + dr, nr);
797
798 for (int ir = ir0; ir < ir1; ++ir) {
799 // src0 and dst are same shape => same indices
800 const int i3 = ir/(ne2*ne1);
801 const int i2 = (ir - i3*ne2*ne1)/ne1;
802 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
803
804#ifdef GGML_USE_ACCELERATE
805 GGML_UNUSED(ggml_vec_add1_f32);
806
807 vDSP_vadd(
808 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
809 (float *) ((char *) src1->data), 0,
810 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
811 ne0);
812#else
813 ggml_vec_add1_f32(ne0,
814 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
815 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
816 *(float *) src1->data);
817#endif
818 }
819}
820
821static void ggml_compute_forward_add1_f16_f32(
822 const ggml_compute_params * params,
823 ggml_tensor * dst) {
824
825 const ggml_tensor * src0 = dst->src[0];
826 const ggml_tensor * src1 = dst->src[1];
827
828 GGML_ASSERT(ggml_are_same_shape(src0, dst));
829 GGML_ASSERT(ggml_is_scalar(src1));
830
831 // scalar to add
832 const float v = *(float *) src1->data;
833
834 const int ith = params->ith;
835 const int nth = params->nth;
836
837 const int nr = ggml_nrows(src0);
838
839 GGML_TENSOR_UNARY_OP_LOCALS
840
841 GGML_ASSERT(src0->type == GGML_TYPE_F16);
842 GGML_ASSERT(src1->type == GGML_TYPE_F32);
843 GGML_ASSERT(dst->type == GGML_TYPE_F16);
844
845 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
846 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
847
848 // rows per thread
849 const int dr = (nr + nth - 1)/nth;
850
851 // row range for this thread
852 const int ir0 = dr*ith;
853 const int ir1 = MIN(ir0 + dr, nr);
854
855 for (int ir = ir0; ir < ir1; ++ir) {
856 // src0 and dst are same shape => same indices
857 const int i3 = ir/(ne2*ne1);
858 const int i2 = (ir - i3*ne2*ne1)/ne1;
859 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
860
861 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
862 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
863 for (int i = 0; i < ne0; i++) {
864 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
865 }
866 }
867}
868
869static void ggml_compute_forward_add1_f16_f16(
870 const ggml_compute_params * params,
871 ggml_tensor * dst) {
872
873 const ggml_tensor * src0 = dst->src[0];
874 const ggml_tensor * src1 = dst->src[1];
875
876 GGML_ASSERT(ggml_are_same_shape(src0, dst));
877 GGML_ASSERT(ggml_is_scalar(src1));
878
879 // scalar to add
880 const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
881
882 const int ith = params->ith;
883 const int nth = params->nth;
884
885 const int nr = ggml_nrows(src0);
886
887 GGML_TENSOR_UNARY_OP_LOCALS
888
889 GGML_ASSERT(src0->type == GGML_TYPE_F16);
890 GGML_ASSERT(src1->type == GGML_TYPE_F16);
891 GGML_ASSERT(dst->type == GGML_TYPE_F16);
892
893 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
894 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
895
896 // rows per thread
897 const int dr = (nr + nth - 1)/nth;
898
899 // row range for this thread
900 const int ir0 = dr*ith;
901 const int ir1 = MIN(ir0 + dr, nr);
902
903 for (int ir = ir0; ir < ir1; ++ir) {
904 // src0 and dst are same shape => same indices
905 const int i3 = ir/(ne2*ne1);
906 const int i2 = (ir - i3*ne2*ne1)/ne1;
907 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
908
909 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
910 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
911 for (int i = 0; i < ne0; i++) {
912 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
913 }
914 }
915}
916
917static void ggml_compute_forward_add1_q_f32(
918 const ggml_compute_params * params,
919 ggml_tensor * dst) {
920
921 const ggml_tensor * src0 = dst->src[0];
922 const ggml_tensor * src1 = dst->src[1];
923
924 GGML_ASSERT(ggml_are_same_shape(src0, dst));
925 GGML_ASSERT(ggml_is_scalar(src1));
926
927 // scalar to add
928 const float v = *(float *) src1->data;
929
930 const int ith = params->ith;
931 const int nth = params->nth;
932
933 const int nr = ggml_nrows(src0);
934
935 GGML_TENSOR_UNARY_OP_LOCALS
936
937 const ggml_type type = src0->type;
938 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
939 ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
940
941 // we don't support permuted src0
942 GGML_ASSERT(nb00 == ggml_type_size(type));
943
944 // dst cannot be transposed or permuted
945 GGML_ASSERT(nb0 <= nb1);
946 GGML_ASSERT(nb1 <= nb2);
947 GGML_ASSERT(nb2 <= nb3);
948
949 GGML_ASSERT(ggml_is_quantized(src0->type));
950 GGML_ASSERT(dst->type == src0->type);
951 GGML_ASSERT(src1->type == GGML_TYPE_F32);
952
953 // rows per thread
954 const int dr = (nr + nth - 1)/nth;
955
956 // row range for this thread
957 const int ir0 = dr*ith;
958 const int ir1 = MIN(ir0 + dr, nr);
959
960 float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
961
962 for (int ir = ir0; ir < ir1; ++ir) {
963 // src0 and dst are same shape => same indices
964 const int i3 = ir/(ne2*ne1);
965 const int i2 = (ir - i3*ne2*ne1)/ne1;
966 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
967
968 void * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
        void * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
970
971 assert(ne0 % 32 == 0);
972
        // dequantize row from src0 to temp buffer
        dequantize_row_q(src0_row, wdata, ne0);
        // add the scalar value from src1
976 ggml_vec_acc1_f32(ne0, wdata, v);
977 // quantize row to dst
978 quantize_row_q(wdata, dst_row, ne0);
979 }
980}
981
982static void ggml_compute_forward_add1_bf16_f32(
983 const ggml_compute_params * params,
984 ggml_tensor * dst) {
985
986 const ggml_tensor * src0 = dst->src[0];
987 const ggml_tensor * src1 = dst->src[1];
988
989 GGML_ASSERT(ggml_are_same_shape(src0, dst));
990 GGML_ASSERT(ggml_is_scalar(src1));
991
992 // scalar to add
993 const float v = *(float *) src1->data;
994
995 const int ith = params->ith;
996 const int nth = params->nth;
997
998 const int nr = ggml_nrows(src0);
999
1000 GGML_TENSOR_UNARY_OP_LOCALS
1001
1002 GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1003 GGML_ASSERT(src1->type == GGML_TYPE_F32);
1004 GGML_ASSERT(dst->type == GGML_TYPE_BF16);
1005
1006 GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1007 GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1008
1009 // rows per thread
1010 const int dr = (nr + nth - 1)/nth;
1011
1012 // row range for this thread
1013 const int ir0 = dr*ith;
1014 const int ir1 = MIN(ir0 + dr, nr);
1015
1016 for (int ir = ir0; ir < ir1; ++ir) {
1017 // src0 and dst are same shape => same indices
1018 const int i3 = ir/(ne2*ne1);
1019 const int i2 = (ir - i3*ne2*ne1)/ne1;
1020 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1021
1022 ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1023 ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1024 for (int i = 0; i < ne0; i++) {
1025 dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1026 }
1027 }
1028}
1029
1030static void ggml_compute_forward_add1_bf16_bf16(
1031 const ggml_compute_params * params,
1032 ggml_tensor * dst) {
1033
1034 const ggml_tensor * src0 = dst->src[0];
1035 const ggml_tensor * src1 = dst->src[1];
1036
1037 GGML_ASSERT(ggml_are_same_shape(src0, dst));
1038 GGML_ASSERT(ggml_is_scalar(src1));
1039
1040 // scalar to add
1041 const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
1042
1043 const int ith = params->ith;
1044 const int nth = params->nth;
1045
1046 const int nr = ggml_nrows(src0);
1047
1048 GGML_TENSOR_UNARY_OP_LOCALS
1049
1050 GGML_ASSERT(src0->type == GGML_TYPE_BF16);
1051 GGML_ASSERT(src1->type == GGML_TYPE_BF16);
1052 GGML_ASSERT(dst->type == GGML_TYPE_BF16);
1053
1054 GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
1055 GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
1056
1057 // rows per thread
1058 const int dr = (nr + nth - 1)/nth;
1059
1060 // row range for this thread
1061 const int ir0 = dr*ith;
1062 const int ir1 = MIN(ir0 + dr, nr);
1063
1064 for (int ir = ir0; ir < ir1; ++ir) {
1065 // src0 and dst are same shape => same indices
1066 const int i3 = ir/(ne2*ne1);
1067 const int i2 = (ir - i3*ne2*ne1)/ne1;
1068 const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
1069
1070 ggml_bf16_t * dst_ptr = (ggml_bf16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
1071 ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
1072 for (int i = 0; i < ne0; i++) {
1073 dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
1074 }
1075 }
1076}
1077
1078void ggml_compute_forward_add1(
1079 const ggml_compute_params * params,
1080 ggml_tensor * dst) {
1081
1082 const ggml_tensor * src0 = dst->src[0];
1083 const ggml_tensor * src1 = dst->src[1];
1084
1085 switch (src0->type) {
1086 case GGML_TYPE_F32:
1087 {
1088 ggml_compute_forward_add1_f32(params, dst);
1089 } break;
1090 case GGML_TYPE_F16:
1091 {
1092 if (src1->type == GGML_TYPE_F16) {
1093 ggml_compute_forward_add1_f16_f16(params, dst);
1094 }
1095 else if (src1->type == GGML_TYPE_F32) {
1096 ggml_compute_forward_add1_f16_f32(params, dst);
1097 }
1098 else {
1099 GGML_ABORT("fatal error");
1100 }
1101 } break;
1102 case GGML_TYPE_BF16:
1103 {
1104 if (src1->type == GGML_TYPE_BF16) {
1105 ggml_compute_forward_add1_bf16_bf16(params, dst);
1106 }
1107 else if (src1->type == GGML_TYPE_F32) {
1108 ggml_compute_forward_add1_bf16_f32(params, dst);
1109 }
1110 else {
1111 GGML_ABORT("fatal error");
1112 }
1113 } break;
1114 case GGML_TYPE_Q4_0:
1115 case GGML_TYPE_Q4_1:
1116 case GGML_TYPE_Q5_0:
1117 case GGML_TYPE_Q5_1:
1118 case GGML_TYPE_Q8_0:
1119 case GGML_TYPE_Q8_1:
1120 case GGML_TYPE_MXFP4:
1121 case GGML_TYPE_Q2_K:
1122 case GGML_TYPE_Q3_K:
1123 case GGML_TYPE_Q4_K:
1124 case GGML_TYPE_Q5_K:
1125 case GGML_TYPE_Q6_K:
1126 case GGML_TYPE_TQ1_0:
1127 case GGML_TYPE_TQ2_0:
1128 case GGML_TYPE_IQ2_XXS:
1129 case GGML_TYPE_IQ2_XS:
1130 case GGML_TYPE_IQ3_XXS:
1131 case GGML_TYPE_IQ1_S:
1132 case GGML_TYPE_IQ1_M:
1133 case GGML_TYPE_IQ4_NL:
1134 case GGML_TYPE_IQ4_XS:
1135 case GGML_TYPE_IQ3_S:
1136 case GGML_TYPE_IQ2_S:
1137 {
1138 ggml_compute_forward_add1_q_f32(params, dst);
1139 } break;
1140 default:
1141 {
1142 GGML_ABORT("fatal error");
1143 }
1144 }
1145}
1146
1147// ggml_compute_forward_acc
1148
1149static void ggml_compute_forward_acc_f32(
1150 const ggml_compute_params * params,
1151 ggml_tensor * dst) {
1152
1153 const ggml_tensor * src0 = dst->src[0];
1154 const ggml_tensor * src1 = dst->src[1];
1155
1156 GGML_ASSERT(ggml_are_same_shape(src0, dst));
1157 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
1158
    // view src0 and dst with these strides and data offset in bytes during acc
1160 // nb0 is implicitly element_size because src0 and dst are contiguous
1161 size_t nb1 = ((int32_t *) dst->op_params)[0];
1162 size_t nb2 = ((int32_t *) dst->op_params)[1];
1163 size_t nb3 = ((int32_t *) dst->op_params)[2];
1164 size_t offset = ((int32_t *) dst->op_params)[3];
1165 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
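    // op_params layout (as written by ggml_acc): [0..2] = nb1/nb2/nb3 byte strides of the
    // view, [3] = byte offset into dst, [4] = inplace flag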
1166
1167 if (!inplace) {
1168 if (params->ith == 0) {
1169 // memcpy needs to be synchronized across threads to avoid race conditions.
1170 // => do it in INIT phase
1171 memcpy(
1172 ((char *) dst->data),
1173 ((char *) src0->data),
1174 ggml_nbytes(dst));
1175 }
1176 ggml_barrier(params->threadpool);
1177 }
1178
1179 const int ith = params->ith;
1180 const int nth = params->nth;
1181
1182 const int nr = ggml_nrows(src1);
1183 const int nc = src1->ne[0];
1184
1185 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
1186 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
1187
1188 // src0 and dst as viewed during acc
1189 const size_t nb0 = ggml_element_size(src0);
1190
1191 const size_t nb00 = nb0;
1192 const size_t nb01 = nb1;
1193 const size_t nb02 = nb2;
1194 const size_t nb03 = nb3;
1195
1196 GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0 + (ne11 == 0 ? 0 : ne11-1)*nb1 + (ne12 == 0 ? 0 : ne12-1)*nb2 + (ne13 == 0 ? 0 : ne13-1)*nb3 < ggml_nbytes(dst));
1197 GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
1198
1199 GGML_ASSERT(nb10 == sizeof(float));
1200
1201 // rows per thread
1202 const int dr = (nr + nth - 1)/nth;
1203
1204 // row range for this thread
1205 const int ir0 = dr*ith;
1206 const int ir1 = MIN(ir0 + dr, nr);
1207
1208 for (int ir = ir0; ir < ir1; ++ir) {
1209 // src0 and dst are viewed with shape of src1 and offset
1210 // => same indices
1211 const int i3 = ir/(ne12*ne11);
1212 const int i2 = (ir - i3*ne12*ne11)/ne11;
1213 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
1214
1215#ifdef GGML_USE_ACCELERATE
1216 vDSP_vadd(
1217 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
1218 (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
1219 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset), 1, nc);
1220#else
1221 ggml_vec_add_f32(nc,
1222 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
1223 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
1224 (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
1225#endif
1226 }
1227}
1228
1229void ggml_compute_forward_acc(
1230 const ggml_compute_params * params,
1231 ggml_tensor * dst) {
1232
1233 const ggml_tensor * src0 = dst->src[0];
1234
1235 switch (src0->type) {
1236 case GGML_TYPE_F32:
1237 {
1238 ggml_compute_forward_acc_f32(params, dst);
1239 } break;
1240 case GGML_TYPE_F16:
1241 case GGML_TYPE_BF16:
1242 case GGML_TYPE_Q4_0:
1243 case GGML_TYPE_Q4_1:
1244 case GGML_TYPE_Q5_0:
1245 case GGML_TYPE_Q5_1:
1246 case GGML_TYPE_Q8_0:
1247 case GGML_TYPE_Q8_1:
1248 case GGML_TYPE_MXFP4:
1249 case GGML_TYPE_Q2_K:
1250 case GGML_TYPE_Q3_K:
1251 case GGML_TYPE_Q4_K:
1252 case GGML_TYPE_Q5_K:
1253 case GGML_TYPE_Q6_K:
1254 case GGML_TYPE_TQ1_0:
1255 case GGML_TYPE_TQ2_0:
1256 case GGML_TYPE_IQ2_XXS:
1257 case GGML_TYPE_IQ2_XS:
1258 case GGML_TYPE_IQ3_XXS:
1259 case GGML_TYPE_IQ1_S:
1260 case GGML_TYPE_IQ1_M:
1261 case GGML_TYPE_IQ4_NL:
1262 case GGML_TYPE_IQ4_XS:
1263 case GGML_TYPE_IQ3_S:
1264 case GGML_TYPE_IQ2_S:
1265 default:
1266 {
1267 GGML_ABORT("fatal error");
1268 }
1269 }
1270}
1271
1272// ggml_compute_forward_sum
1273
1274static void ggml_compute_forward_sum_f32(
1275 const ggml_compute_params * params,
1276 ggml_tensor * dst) {
1277
1278 const ggml_tensor * src0 = dst->src[0];
1279
1280 if (params->ith != 0) {
1281 return;
1282 }
1283
1284 assert(ggml_is_scalar(dst));
1285 assert(src0->nb[0] == sizeof(float));
1286
1287 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1288 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1289
1290 ggml_float sum = 0;
1291 ggml_float row_sum = 0;
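    // ggml_float is a wider accumulator type (double) so that summing many f32 values does
    // not lose precision; the result is narrowed back to f32 only at the very end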
1292
1293 for (int64_t i03 = 0; i03 < ne03; i03++) {
1294 for (int64_t i02 = 0; i02 < ne02; i02++) {
1295 for (int64_t i01 = 0; i01 < ne01; i01++) {
1296 ggml_vec_sum_f32_ggf(ne00,
1297 &row_sum,
1298 (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1299 sum += row_sum;
1300 }
1301 }
1302 }
1303 ((float *) dst->data)[0] = sum;
1304}
1305
1306static void ggml_compute_forward_sum_f16(
1307 const ggml_compute_params * params,
1308 ggml_tensor * dst) {
1309
1310 const ggml_tensor * src0 = dst->src[0];
1311
1312 if (params->ith != 0) {
1313 return;
1314 }
1315
1316 assert(ggml_is_scalar(dst));
1317
1318 assert(src0->nb[0] == sizeof(ggml_fp16_t));
1319
1320 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1321 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1322
1323 float sum = 0;
1324 float row_sum = 0;
1325
1326 for (int64_t i03 = 0; i03 < ne03; i03++) {
1327 for (int64_t i02 = 0; i02 < ne02; i02++) {
1328 for (int64_t i01 = 0; i01 < ne01; i01++) {
1329 ggml_vec_sum_f16_ggf(ne00,
1330 &row_sum,
1331 (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1332 sum += row_sum;
1333 }
1334 }
1335 }
1336 ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
1337}
1338
1339static void ggml_compute_forward_sum_bf16(
1340 const ggml_compute_params * params,
1341 ggml_tensor * dst) {
1342
1343 const ggml_tensor * src0 = dst->src[0];
1344
1345 if (params->ith != 0) {
1346 return;
1347 }
1348
1349 assert(ggml_is_scalar(dst));
1350
1351 assert(src0->nb[0] == sizeof(ggml_bf16_t));
1352
1353 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
1354 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
1355
1356 float sum = 0;
1357 float row_sum = 0;
1358
1359 for (int64_t i03 = 0; i03 < ne03; i03++) {
1360 for (int64_t i02 = 0; i02 < ne02; i02++) {
1361 for (int64_t i01 = 0; i01 < ne01; i01++) {
1362 ggml_vec_sum_bf16_ggf(ne00,
1363 &row_sum,
1364 (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
1365 sum += row_sum;
1366 }
1367 }
1368 }
1369 ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
1370}
1371
1372void ggml_compute_forward_sum(
1373 const ggml_compute_params * params,
1374 ggml_tensor * dst) {
1375
1376 const ggml_tensor * src0 = dst->src[0];
1377
1378 switch (src0->type) {
1379 case GGML_TYPE_F32:
1380 {
1381 ggml_compute_forward_sum_f32(params, dst);
1382 } break;
1383 case GGML_TYPE_F16:
1384 {
1385 ggml_compute_forward_sum_f16(params, dst);
1386 } break;
1387 case GGML_TYPE_BF16:
1388 {
1389 ggml_compute_forward_sum_bf16(params, dst);
1390 } break;
1391 default:
1392 {
1393 GGML_ABORT("fatal error");
1394 }
1395 }
1396}
1397
1398// ggml_compute_forward_cumsum
1399
1400static void ggml_compute_forward_cumsum_f32(
1401 const ggml_compute_params * params,
1402 ggml_tensor * dst) {
1403
1404 const ggml_tensor * src0 = dst->src[0];
1405
1406 GGML_ASSERT(src0->nb[0] == sizeof(float));
1407 GGML_ASSERT(dst->nb[0] == sizeof(float));
1408
1409 GGML_TENSOR_UNARY_OP_LOCALS
1410
1411 GGML_ASSERT(ne0 == ne00);
1412 GGML_ASSERT(ne1 == ne01);
1413 GGML_ASSERT(ne2 == ne02);
1414 GGML_ASSERT(ne3 == ne03);
1415
1416 const auto [ir0, ir1] = get_thread_range(params, src0);
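    // get_thread_range() splits the rows of src0 evenly across threads, equivalent to the
    // manual dr/ir0/ir1 computation used elsewhere in this file (assumption based on its use here)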
1417
1418 for (int64_t ir = ir0; ir < ir1; ++ir) {
1419 const int64_t i03 = ir/(ne02*ne01);
1420 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
1421 const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
1422
1423 float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
1424 float * dst_row = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
1425
1426 ggml_vec_cumsum_f32(ne00, dst_row, src_row);
1427 }
1428}
1429
1430void ggml_compute_forward_cumsum(
1431 const ggml_compute_params * params,
1432 ggml_tensor * dst) {
1433
1434 const ggml_tensor * src0 = dst->src[0];
1435
1436 switch (src0->type) {
1437 case GGML_TYPE_F32:
1438 {
1439 ggml_compute_forward_cumsum_f32(params, dst);
1440 } break;
1441 default:
1442 {
1443 GGML_ABORT("fatal error");
1444 }
1445 }
1446}
1447
1448// ggml_compute_forward_sum_rows
1449
1450static void ggml_compute_forward_sum_rows_f32(
1451 const ggml_compute_params * params,
1452 ggml_tensor * dst) {
1453
1454 const ggml_tensor * src0 = dst->src[0];
1455
1456 if (params->ith != 0) {
1457 return;
1458 }
1459
1460 GGML_ASSERT(src0->nb[0] == sizeof(float));
1461 GGML_ASSERT(dst->nb[0] == sizeof(float));
1462
1463 GGML_TENSOR_UNARY_OP_LOCALS
1464
1465 GGML_ASSERT(ne0 == 1);
1466 GGML_ASSERT(ne1 == ne01);
1467 GGML_ASSERT(ne2 == ne02);
1468 GGML_ASSERT(ne3 == ne03);
1469
1470 for (int64_t i3 = 0; i3 < ne03; i3++) {
1471 for (int64_t i2 = 0; i2 < ne02; i2++) {
1472 for (int64_t i1 = 0; i1 < ne01; i1++) {
1473 float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
1474 float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
1475 float row_sum = 0;
1476 ggml_vec_sum_f32(ne00, &row_sum, src_row);
1477 dst_row[0] = row_sum;
1478 }
1479 }
1480 }
1481}
1482
1483void ggml_compute_forward_sum_rows(
1484 const ggml_compute_params * params,
1485 ggml_tensor * dst) {
1486
1487 const ggml_tensor * src0 = dst->src[0];
1488
1489 switch (src0->type) {
1490 case GGML_TYPE_F32:
1491 {
1492 ggml_compute_forward_sum_rows_f32(params, dst);
1493 } break;
1494 default:
1495 {
1496 GGML_ABORT("fatal error");
1497 }
1498 }
1499}
1500
1501// ggml_compute_forward_mean
1502
1503static void ggml_compute_forward_mean_f32(
1504 const ggml_compute_params * params,
1505 ggml_tensor * dst) {
1506
1507 const ggml_tensor * src0 = dst->src[0];
1508
1509 if (params->ith != 0) {
1510 return;
1511 }
1512
1513 assert(src0->nb[0] == sizeof(float));
1514
1515 GGML_TENSOR_UNARY_OP_LOCALS
1516
1517 assert(ne0 == 1);
1518 assert(ne1 == ne01);
1519 assert(ne2 == ne02);
1520 assert(ne3 == ne03);
1521
1522 GGML_UNUSED(ne0);
1523 GGML_UNUSED(ne1);
1524 GGML_UNUSED(ne2);
1525 GGML_UNUSED(ne3);
1526
1527 for (int64_t i03 = 0; i03 < ne03; i03++) {
1528 for (int64_t i02 = 0; i02 < ne02; i02++) {
1529 for (int64_t i01 = 0; i01 < ne01; i01++) {
1530 ggml_vec_sum_f32(ne00,
1531 (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3),
1532 (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
1533
1534 *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
1535 }
1536 }
1537 }
1538}
1539
1540void ggml_compute_forward_mean(
1541 const ggml_compute_params * params,
1542 ggml_tensor * dst) {
1543
1544 const ggml_tensor * src0 = dst->src[0];
1545
1546 switch (src0->type) {
1547 case GGML_TYPE_F32:
1548 {
1549 ggml_compute_forward_mean_f32(params, dst);
1550 } break;
1551 default:
1552 {
1553 GGML_ABORT("fatal error");
1554 }
1555 }
1556}
1557
1558// ggml_compute_forward_argmax
1559
1560static void ggml_compute_forward_argmax_f32(
1561 const ggml_compute_params * params,
1562 ggml_tensor * dst) {
1563
1564 const ggml_tensor * src0 = dst->src[0];
1565
1566 if (params->ith != 0) {
1567 return;
1568 }
1569
1570 assert(src0->nb[0] == sizeof(float));
1571 assert(dst->nb[0] == sizeof(float));
1572
1573 const int64_t ne00 = src0->ne[0];
1574 const int64_t ne01 = src0->ne[1];
1575
1576 const size_t nb01 = src0->nb[1];
1577 const size_t nb0 = dst->nb[0];
1578
1579 for (int64_t i1 = 0; i1 < ne01; i1++) {
1580 float * src = (float *) ((char *) src0->data + i1*nb01);
1581 int32_t * dst_ = (int32_t *) ((char *) dst->data + i1*nb0);
1582 int v = 0;
1583 ggml_vec_argmax_f32(ne00, &v, src);
1584 dst_[0] = v;
1585 }
1586}
1587
1588void ggml_compute_forward_argmax(
1589 const ggml_compute_params * params,
1590 ggml_tensor * dst) {
1591
1592 const ggml_tensor * src0 = dst->src[0];
1593
1594 switch (src0->type) {
1595 case GGML_TYPE_F32:
1596 {
1597 ggml_compute_forward_argmax_f32(params, dst);
1598 } break;
1599 default:
1600 {
1601 GGML_ABORT("fatal error");
1602 }
1603 }
1604}
1605
1606// ggml_compute_forward_count_equal
1607
1608static void ggml_compute_forward_count_equal_i32(
1609 const ggml_compute_params * params,
1610 ggml_tensor * dst) {
1611
1612 const ggml_tensor * src0 = dst->src[0];
1613 const ggml_tensor * src1 = dst->src[1];
1614
1615 GGML_TENSOR_BINARY_OP_LOCALS;
1616
1617 GGML_ASSERT(src0->type == GGML_TYPE_I32);
1618 GGML_ASSERT(src1->type == GGML_TYPE_I32);
1619 GGML_ASSERT(ggml_are_same_shape(src0, src1));
1620 GGML_ASSERT(ggml_is_scalar(dst));
1621 GGML_ASSERT(dst->type == GGML_TYPE_I64);
1622
1623 const int64_t nr = ggml_nrows(src0);
1624
1625 const int ith = params->ith;
1626 const int nth = params->nth;
1627
1628 int64_t * sums = (int64_t *) params->wdata;
1629 int64_t sum_thread = 0;
1630
1631 // rows per thread
1632 const int64_t dr = (nr + nth - 1)/nth;
1633
1634 // row range for this thread
1635 const int64_t ir0 = dr*ith;
1636 const int64_t ir1 = MIN(ir0 + dr, nr);
1637
1638 for (int64_t ir = ir0; ir < ir1; ++ir) {
1639 const int64_t i03 = ir / (ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
1642
1643 const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
1644 const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
1645
1646 for (int64_t i00 = 0; i00 < ne00; ++i00) {
1647 const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
1648 const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
1649
1650 sum_thread += val0 == val1;
1651 }
1652 }
1653 if (ith != 0) {
1654 sums[ith] = sum_thread;
1655 }
1656 ggml_barrier(params->threadpool);
1657
1658 if (ith != 0) {
1659 return;
1660 }
1661
1662 for (int ith_other = 1; ith_other < nth; ++ith_other) {
1663 sum_thread += sums[ith_other];
1664 }
1665 *((int64_t *) dst->data) = sum_thread;
1666}
1667
1668void ggml_compute_forward_count_equal(
1669 const ggml_compute_params * params,
1670 ggml_tensor * dst) {
1671
1672 const ggml_tensor * src0 = dst->src[0];
1673
1674 switch (src0->type) {
1675 case GGML_TYPE_I32:
1676 {
1677 ggml_compute_forward_count_equal_i32(params, dst);
1678 } break;
1679 default:
1680 {
1681 GGML_ABORT("fatal error");
1682 }
1683 }
1684}
1685
1686// ggml_compute_forward_repeat
1687
1688static void ggml_compute_forward_repeat_f32(
1689 const ggml_compute_params * params,
1690 ggml_tensor * dst) {
1691
1692 const ggml_tensor * src0 = dst->src[0];
1693
1694 if (params->ith != 0) {
1695 return;
1696 }
1697
1698 GGML_ASSERT(ggml_can_repeat(src0, dst));
1699
1700 GGML_TENSOR_UNARY_OP_LOCALS
1701
1702 // guaranteed to be an integer due to the check in ggml_can_repeat
1703 const int nr0 = (int)(ne0/ne00);
1704 const int nr1 = (int)(ne1/ne01);
1705 const int nr2 = (int)(ne2/ne02);
1706 const int nr3 = (int)(ne3/ne03);
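    // e.g. repeating src0 with ne = [2, 3, 1, 1] into dst with ne = [4, 3, 2, 1] gives
    // nr0 = 2, nr1 = 1, nr2 = 2, nr3 = 1: each source row is copied twice along dim 0 and
    // the whole tensor twice along dim 2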
1707
1708 // TODO: support for transposed / permuted tensors
1709 GGML_ASSERT(nb0 == sizeof(float));
1710 GGML_ASSERT(nb00 == sizeof(float));
1711
1712 // TODO: maybe this is not optimal?
1713 for (int i3 = 0; i3 < nr3; i3++) {
1714 for (int k3 = 0; k3 < ne03; k3++) {
1715 for (int i2 = 0; i2 < nr2; i2++) {
1716 for (int k2 = 0; k2 < ne02; k2++) {
1717 for (int i1 = 0; i1 < nr1; i1++) {
1718 for (int k1 = 0; k1 < ne01; k1++) {
1719 for (int i0 = 0; i0 < nr0; i0++) {
1720 ggml_vec_cpy_f32(ne00,
1721 (float *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0),
1722 (float *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01));
1723 }
1724 }
1725 }
1726 }
1727 }
1728 }
1729 }
1730}
1731
1732static void ggml_compute_forward_repeat_f16(
1733 const ggml_compute_params * params,
1734 ggml_tensor * dst) {
1735
1736 const ggml_tensor * src0 = dst->src[0];
1737
1738 if (params->ith != 0) {
1739 return;
1740 }
1741
1742 GGML_ASSERT(ggml_can_repeat(src0, dst));
1743
1744 GGML_TENSOR_UNARY_OP_LOCALS
1745
1746 // guaranteed to be an integer due to the check in ggml_can_repeat
1747 const int nr0 = (int)(ne0/ne00);
1748 const int nr1 = (int)(ne1/ne01);
1749 const int nr2 = (int)(ne2/ne02);
1750 const int nr3 = (int)(ne3/ne03);
1751
1752 // TODO: support for transposed / permuted tensors
1753 GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
1754 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
1755
1756 // TODO: maybe this is not optimal?
1757 for (int i3 = 0; i3 < nr3; i3++) {
1758 for (int k3 = 0; k3 < ne03; k3++) {
1759 for (int i2 = 0; i2 < nr2; i2++) {
1760 for (int k2 = 0; k2 < ne02; k2++) {
1761 for (int i1 = 0; i1 < nr1; i1++) {
1762 for (int k1 = 0; k1 < ne01; k1++) {
1763 for (int i0 = 0; i0 < nr0; i0++) {
1764 ggml_fp16_t * y = (ggml_fp16_t *) ((char *) dst->data + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0);
1765 ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01);
1766 // ggml_vec_cpy_f16(ne00, y, x)
1767 for (int i = 0; i < ne00; ++i) {
1768 y[i] = x[i];
1769 }
1770 }
1771 }
1772 }
1773 }
1774 }
1775 }
1776 }
1777}
1778
1779void ggml_compute_forward_repeat(
1780 const ggml_compute_params * params,
1781 ggml_tensor * dst) {
1782
1783 const ggml_tensor * src0 = dst->src[0];
1784
1785 switch (src0->type) {
1786 case GGML_TYPE_F16:
1787 case GGML_TYPE_BF16:
1788 case GGML_TYPE_I16:
1789 {
1790 ggml_compute_forward_repeat_f16(params, dst);
1791 } break;
1792 case GGML_TYPE_F32:
1793 case GGML_TYPE_I32:
1794 {
1795 ggml_compute_forward_repeat_f32(params, dst);
1796 } break;
        // TODO: templatify the implementation and add support for I64
1798 // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
1799 //case GGML_TYPE_I64:
1800 // {
1801 // ggml_compute_forward_repeat_i64(params, dst);
1802 // } break;
1803 default:
1804 {
1805 GGML_ABORT("fatal error");
1806 }
1807 }
1808}
1809
1810// ggml_compute_forward_repeat_back
1811
1812static void ggml_compute_forward_repeat_back_f32(
1813 const ggml_compute_params * params,
1814 ggml_tensor * dst) {
1815
1816 const ggml_tensor * src0 = dst->src[0];
1817
1818 if (params->ith != 0) {
1819 return;
1820 }
1821
1822 GGML_ASSERT(ggml_can_repeat(dst, src0));
1823
1824 GGML_TENSOR_UNARY_OP_LOCALS
1825
1826 // guaranteed to be an integer due to the check in ggml_can_repeat
1827 const int nr0 = (int)(ne00/ne0);
1828 const int nr1 = (int)(ne01/ne1);
1829 const int nr2 = (int)(ne02/ne2);
1830 const int nr3 = (int)(ne03/ne3);
1831
1832 // TODO: support for transposed / permuted tensors
1833 GGML_ASSERT(nb0 == sizeof(float));
1834 GGML_ASSERT(nb00 == sizeof(float));
1835
1836 if (ggml_is_contiguous(dst)) {
1837 ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
1838 } else {
1839 for (int k3 = 0; k3 < ne3; k3++) {
1840 for (int k2 = 0; k2 < ne2; k2++) {
1841 for (int k1 = 0; k1 < ne1; k1++) {
1842 ggml_vec_set_f32(ne0,
1843 (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
1844 0);
1845 }
1846 }
1847 }
1848 }
1849
1850 // TODO: maybe this is not optimal?
1851 for (int i3 = 0; i3 < nr3; i3++) {
1852 for (int k3 = 0; k3 < ne3; k3++) {
1853 for (int i2 = 0; i2 < nr2; i2++) {
1854 for (int k2 = 0; k2 < ne2; k2++) {
1855 for (int i1 = 0; i1 < nr1; i1++) {
1856 for (int k1 = 0; k1 < ne1; k1++) {
1857 for (int i0 = 0; i0 < nr0; i0++) {
1858 ggml_vec_acc_f32(ne0,
1859 (float *) ((char *) dst->data + ( k3)*nb3 + ( k2)*nb2 + ( k1)*nb1),
1860 (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
1861 }
1862 }
1863 }
1864 }
1865 }
1866 }
1867 }
1868}
1869
1870void ggml_compute_forward_repeat_back(
1871 const ggml_compute_params * params,
1872 ggml_tensor * dst) {
1873
1874 const ggml_tensor * src0 = dst->src[0];
1875
1876 switch (src0->type) {
1877 case GGML_TYPE_F32:
1878 {
1879 ggml_compute_forward_repeat_back_f32(params, dst);
1880 } break;
1881 default:
1882 {
1883 GGML_ABORT("fatal error");
1884 }
1885 }
1886}
1887
1888// ggml_compute_forward_concat
1889
1890static void ggml_compute_forward_concat_any(
1891 const ggml_compute_params * params,
1892 ggml_tensor * dst) {
1893
1894 const ggml_tensor * src0 = dst->src[0];
1895 const ggml_tensor * src1 = dst->src[1];
1896
1897 const size_t len = ggml_type_size(src0->type);
1898
1899 const int ith = params->ith;
1900 const int nth = params->nth;
1901
1902 GGML_TENSOR_BINARY_OP_LOCALS
1903
1904 const int32_t dim = ggml_get_op_params_i32(dst, 0);
1905
1906 GGML_ASSERT(dim >= 0 && dim < 4);
1907
1908 int64_t o[4] = {0, 0, 0, 0};
1909 o[dim] = src0->ne[dim];
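    // o[] is the offset of src1 within dst along the concat dimension; e.g. for dim == 2
    // an element with i2 >= ne02 falls outside src0 and is read from src1 at i2 - ne02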
1910
1911 const char * x;
1912
    // TODO: smarter multi-threading
1914 for (int i3 = 0; i3 < ne3; i3++) {
1915 for (int i2 = ith; i2 < ne2; i2 += nth) {
1916 for (int i1 = 0; i1 < ne1; i1++) {
1917 for (int i0 = 0; i0 < ne0; i0++) {
1918 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1919 x = (const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03;
1920 } else {
1921 x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
1922 }
1923
1924 char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
1925
1926 memcpy(y, x, len);
1927 }
1928 }
1929 }
1930 }
1931}
1932
1933static void ggml_compute_forward_concat_i8(
1934 const ggml_compute_params * params,
1935 ggml_tensor * dst) {
1936
1937 const ggml_tensor * src0 = dst->src[0];
1938 const ggml_tensor * src1 = dst->src[1];
1939
1940 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
1941
1942 const int ith = params->ith;
1943 const int nth = params->nth;
1944
1945 GGML_TENSOR_BINARY_OP_LOCALS
1946
1947 const int32_t dim = ggml_get_op_params_i32(dst, 0);
1948
1949 GGML_ASSERT(dim >= 0 && dim < 4);
1950
1951 int64_t o[4] = {0, 0, 0, 0};
1952 o[dim] = src0->ne[dim];
1953
1954 const int8_t * x;
1955
    // TODO: smarter multi-threading
1957 for (int i3 = 0; i3 < ne3; i3++) {
1958 for (int i2 = ith; i2 < ne2; i2 += nth) {
1959 for (int i1 = 0; i1 < ne1; i1++) {
1960 for (int i0 = 0; i0 < ne0; i0++) {
1961 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
1962 x = (const int8_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
1963 } else {
1964 x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
1965 }
1966
1967 int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
1968
1969 *y = *x;
1970 }
1971 }
1972 }
1973 }
1974}
1975
1976static void ggml_compute_forward_concat_f16(
1977 const ggml_compute_params * params,
1978 ggml_tensor * dst) {
1979
1980 const ggml_tensor * src0 = dst->src[0];
1981 const ggml_tensor * src1 = dst->src[1];
1982
1983 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
1984
1985 const int ith = params->ith;
1986 const int nth = params->nth;
1987
1988 GGML_TENSOR_BINARY_OP_LOCALS
1989
1990 const int32_t dim = ggml_get_op_params_i32(dst, 0);
1991
1992 GGML_ASSERT(dim >= 0 && dim < 4);
1993
1994 int64_t o[4] = {0, 0, 0, 0};
1995 o[dim] = src0->ne[dim];
1996
1997 const ggml_fp16_t * x;
1998
    // TODO: smarter multi-threading
2000 for (int i3 = 0; i3 < ne3; i3++) {
2001 for (int i2 = ith; i2 < ne2; i2 += nth) {
2002 for (int i1 = 0; i1 < ne1; i1++) {
2003 for (int i0 = 0; i0 < ne0; i0++) {
2004 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2005 x = (const ggml_fp16_t *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
2006 } else {
2007 x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2008 }
2009
2010 ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2011
2012 *y = *x;
2013 }
2014 }
2015 }
2016 }
2017}
2018
2019static void ggml_compute_forward_concat_f32(
2020 const ggml_compute_params * params,
2021 ggml_tensor * dst) {
2022
2023 const ggml_tensor * src0 = dst->src[0];
2024 const ggml_tensor * src1 = dst->src[1];
2025
2026 GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
2027
2028 const int ith = params->ith;
2029 const int nth = params->nth;
2030
2031 GGML_TENSOR_BINARY_OP_LOCALS
2032
2033 const int32_t dim = ggml_get_op_params_i32(dst, 0);
2034
2035 GGML_ASSERT(dim >= 0 && dim < 4);
2036
2037 int64_t o[4] = {0, 0, 0, 0};
2038 o[dim] = src0->ne[dim];
2039
2040 const float * x;
2041
    // TODO: smarter multi-threading
2043 for (int i3 = 0; i3 < ne3; i3++) {
2044 for (int i2 = ith; i2 < ne2; i2 += nth) {
2045 for (int i1 = 0; i1 < ne1; i1++) {
2046 for (int i0 = 0; i0 < ne0; i0++) {
2047 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
2048 x = (const float *) ((const char *)src0->data + (i0 )*nb00 + (i1 )*nb01 + (i2 )*nb02 + (i3 )*nb03);
2049 } else {
2050 x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
2051 }
2052
2053 float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
2054
2055 *y = *x;
2056 }
2057 }
2058 }
2059 }
2060}
2061
2062void ggml_compute_forward_concat(
2063 const ggml_compute_params * params,
2064 ggml_tensor * dst) {
2065
2066 const ggml_tensor * src0 = dst->src[0];
2067
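    // dispatch by element size: the 2-byte types share the f16 path, the 4-byte types share
    // the f32 path, anything else falls back to the generic per-element copy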
2068 switch (src0->type) {
2069 case GGML_TYPE_F16:
2070 case GGML_TYPE_BF16:
2071 case GGML_TYPE_I16:
2072 {
2073 ggml_compute_forward_concat_f16(params, dst);
2074 } break;
2075 case GGML_TYPE_I8:
2076 {
2077 ggml_compute_forward_concat_i8(params, dst);
2078 } break;
2079 case GGML_TYPE_F32:
2080 case GGML_TYPE_I32:
2081 {
2082 ggml_compute_forward_concat_f32(params, dst);
2083 } break;
2084 default:
2085 {
2086 ggml_compute_forward_concat_any(params, dst);
2087 }
2088 }
2089}
2090
2091// ggml_compute_forward_gelu
2092
2093static void ggml_compute_forward_gelu_f32(
2094 const ggml_compute_params * params,
2095 ggml_tensor * dst) {
2096
2097 const ggml_tensor * src0 = dst->src[0];
2098
2099 assert(ggml_is_contiguous_rows(src0));
2100 assert(ggml_are_same_shape(src0, dst));
2101
2102 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2103 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2104 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2105 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2106
2107 const int ith = params->ith;
2108 const int nth = params->nth;
2109
2110 const int nc = src0->ne[0];
2111 const int nr = ggml_nrows(src0);
2112
2113 // rows per thread
2114 const int dr = (nr + nth - 1)/nth;
2115
2116 // row range for this thread
2117 const int ir0 = dr*ith;
2118 const int ir1 = MIN(ir0 + dr, nr);
2119
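    // process this thread's rows; each flattened row index ir is decomposed into the (i1, i2, i3) row coordinates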
2120 for (int ir = ir0; ir < ir1; ++ir) {
2121 const int i3 = ir/(ne02*ne01);
2122 const int i2 = (ir - i3*ne02*ne01)/ne01;
2123 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2124
2125 ggml_vec_gelu_f32(nc,
2126 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2127 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2128
2129#ifndef NDEBUG
2130 for (int k = 0; k < nc; k++) {
2131 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2132 GGML_UNUSED(x);
2133 assert(!isnan(x));
2134 assert(!isinf(x));
2135 }
2136#endif
2137 }
2138}
2139
2140static void ggml_compute_forward_gelu_f16(
2141 const ggml_compute_params * params,
2142 ggml_tensor * dst) {
2143
2144 const ggml_tensor * src0 = dst->src[0];
2145
2146 assert(ggml_is_contiguous_rows(src0));
2147 assert(ggml_are_same_shape(src0, dst));
2148
2149 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2150 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2151 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2152 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2153
2154 const int ith = params->ith;
2155 const int nth = params->nth;
2156
2157 const int nc = src0->ne[0];
2158 const int nr = ggml_nrows(src0);
2159
2160 // rows per thread
2161 const int dr = (nr + nth - 1)/nth;
2162
2163 // row range for this thread
2164 const int ir0 = dr*ith;
2165 const int ir1 = MIN(ir0 + dr, nr);
2166
2167 for (int ir = ir0; ir < ir1; ++ir) {
2168 const int i3 = ir/(ne02*ne01);
2169 const int i2 = (ir - i3*ne02*ne01)/ne01;
2170 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2171
2172 ggml_vec_gelu_f16(nc,
2173 (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2174 (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2175
2176#ifndef NDEBUG
2177 for (int k = 0; k < nc; k++) {
2178 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2179 const float v = GGML_CPU_FP16_TO_FP32(x);
2180 GGML_UNUSED(v);
2181 assert(!isnan(v));
2182 assert(!isinf(v));
2183 }
2184#endif
2185 }
2186}
2187
2188static void ggml_compute_forward_gelu(
2189 const ggml_compute_params * params,
2190 ggml_tensor * dst) {
2191
2192 const ggml_tensor * src0 = dst->src[0];
2193
2194 switch (src0->type) {
2195 case GGML_TYPE_F32:
2196 {
2197 ggml_compute_forward_gelu_f32(params, dst);
2198 } break;
2199 case GGML_TYPE_F16:
2200 {
2201 ggml_compute_forward_gelu_f16(params, dst);
2202 } break;
2203 default:
2204 {
2205 GGML_ABORT("fatal error");
2206 }
2207 }
2208}
2209
// ggml_compute_forward_fill
2211
2212static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2213 const float c = ggml_get_op_params_f32(dst, 0);
2214
2215 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
2216 GGML_TENSOR_LOCALS(size_t, nb, dst, nb);
2217
2218 const auto [ir0, ir1] = get_thread_range(params, dst);
2219
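    // set every element of each row in this thread's range to the constant c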
2220 for (int64_t ir = ir0; ir < ir1; ++ir) {
2221 const int64_t i03 = ir/(ne2*ne1);
2222 const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
2223 const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
2224
2225 float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2226
2227 ggml_vec_set_f32(ne0, dst_ptr, c);
2228 }
2229}
2230
2231void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
2232 ggml_compute_forward_fill_f32(params, dst);
2233}
2234
// ggml_compute_forward_tri
2236
2237static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
2238 const ggml_tensor * src0 = dst->src[0];
2239
2240 const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
2241
2242 GGML_ASSERT(ggml_is_contiguous(src0));
2243
2244 GGML_TENSOR_UNARY_OP_LOCALS
2245
2246 const auto [ir0, ir1] = get_thread_range(params, src0);
2247
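    // bipred(column, row) decides whether the source element is kept; everything else is zeroed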
2248 bool (*bipred)(int, int);
2249
2250 switch (ttype) {
2251 case GGML_TRI_TYPE_LOWER: bipred = [](int i, int r) { return i < r; }; break;
2252 case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
2253 case GGML_TRI_TYPE_UPPER: bipred = [](int i, int r) { return i > r; }; break;
2254 case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
2255 default: GGML_ABORT("invalid tri type");
2256 }
2257
2258 for (int64_t ir = ir0; ir < ir1; ++ir) {
2259 const int64_t i03 = ir/(ne02*ne01);
2260 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
2261 const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
2262
2263 const float * src_ptr = (const float *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
2264 float * dst_ptr = ( float *) (( char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
2265
2266 for (int i0 = 0; i0 < ne0; ++i0) {
2267 dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
2268 }
2269 }
2270}
2271
2272void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
2273 const ggml_tensor * src0 = dst->src[0];
2274
2275 switch (src0->type) {
2276 case GGML_TYPE_F32:
2277 {
2278 ggml_compute_forward_tri_f32(params, dst);
2279 } break;
2280 default:
2281 {
2282 GGML_ABORT("fatal error");
2283 }
2284 }
2285}
2286
2287// ggml_compute_forward_gelu_erf
2288
2289static void ggml_compute_forward_gelu_erf_f32(
2290 const ggml_compute_params * params,
2291 ggml_tensor * dst) {
2292
2293 const ggml_tensor * src0 = dst->src[0];
2294
2295 assert(ggml_is_contiguous_rows(src0));
2296 assert(ggml_are_same_shape(src0, dst));
2297
2298 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2299 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2300 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2301 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2302
2303 const int ith = params->ith;
2304 const int nth = params->nth;
2305
2306 const int nc = src0->ne[0];
2307 const int nr = ggml_nrows(src0);
2308
2309 // rows per thread
2310 const int dr = (nr + nth - 1)/nth;
2311
2312 // row range for this thread
2313 const int ir0 = dr*ith;
2314 const int ir1 = MIN(ir0 + dr, nr);
2315
2316 for (int ir = ir0; ir < ir1; ++ir) {
2317 const int i3 = ir/(ne02*ne01);
2318 const int i2 = (ir - i3*ne02*ne01)/ne01;
2319 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2320
2321 ggml_vec_gelu_erf_f32(nc,
2322 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2323 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2324
2325#ifndef NDEBUG
2326 for (int k = 0; k < nc; k++) {
2327 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2328 GGML_UNUSED(x);
2329 assert(!isnan(x));
2330 assert(!isinf(x));
2331 }
2332#endif
2333 }
2334}
2335
2336static void ggml_compute_forward_gelu_erf_f16(
2337 const ggml_compute_params * params,
2338 ggml_tensor * dst) {
2339
2340 const ggml_tensor * src0 = dst->src[0];
2341
2342 assert(ggml_is_contiguous_rows(src0));
2343 assert(ggml_are_same_shape(src0, dst));
2344
2345 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2346 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2347 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2348 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2349
2350 const int ith = params->ith;
2351 const int nth = params->nth;
2352
2353 const int nc = src0->ne[0];
2354 const int nr = ggml_nrows(src0);
2355
2356 // rows per thread
2357 const int dr = (nr + nth - 1)/nth;
2358
2359 // row range for this thread
2360 const int ir0 = dr*ith;
2361 const int ir1 = MIN(ir0 + dr, nr);
2362
2363 for (int ir = ir0; ir < ir1; ++ir) {
2364 const int i3 = ir/(ne02*ne01);
2365 const int i2 = (ir - i3*ne02*ne01)/ne01;
2366 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2367
2368 ggml_vec_gelu_erf_f16(nc,
2369 (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2370 (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2371
2372#ifndef NDEBUG
2373 for (int k = 0; k < nc; k++) {
2374 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2375 const float v = GGML_CPU_FP16_TO_FP32(x);
2376 GGML_UNUSED(v);
2377 assert(!isnan(v));
2378 assert(!isinf(v));
2379 }
2380#endif
2381 }
2382}
2383
2384static void ggml_compute_forward_gelu_erf(
2385 const ggml_compute_params * params,
2386 ggml_tensor * dst) {
2387
2388 const ggml_tensor * src0 = dst->src[0];
2389
2390 switch (src0->type) {
2391 case GGML_TYPE_F32:
2392 {
2393 ggml_compute_forward_gelu_erf_f32(params, dst);
2394 } break;
2395 case GGML_TYPE_F16:
2396 {
2397 ggml_compute_forward_gelu_erf_f16(params, dst);
2398 } break;
2399 default:
2400 {
2401 GGML_ABORT("fatal error");
2402 }
2403 }
2404}
2405
2406// ggml_compute_forward_gelu_quick
2407
2408static void ggml_compute_forward_gelu_quick_f32(
2409 const ggml_compute_params * params,
2410 ggml_tensor * dst) {
2411
2412 const ggml_tensor * src0 = dst->src[0];
2413
2414 assert(ggml_is_contiguous_rows(src0));
2415 assert(ggml_are_same_shape(src0, dst));
2416
2417 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2418 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2419 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2420 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2421
2422 const int ith = params->ith;
2423 const int nth = params->nth;
2424
2425 const int nc = src0->ne[0];
2426 const int nr = ggml_nrows(src0);
2427
2428 // rows per thread
2429 const int dr = (nr + nth - 1)/nth;
2430
2431 // row range for this thread
2432 const int ir0 = dr*ith;
2433 const int ir1 = MIN(ir0 + dr, nr);
2434
2435 for (int ir = ir0; ir < ir1; ++ir) {
2436 const int i3 = ir/(ne02*ne01);
2437 const int i2 = (ir - i3*ne02*ne01)/ne01;
2438 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2439
2440 ggml_vec_gelu_quick_f32(nc,
2441 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2442 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2443
2444#ifndef NDEBUG
2445 for (int k = 0; k < nc; k++) {
2446 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2447 GGML_UNUSED(x);
2448 assert(!isnan(x));
2449 assert(!isinf(x));
2450 }
2451#endif
2452 }
2453}
2454
2455static void ggml_compute_forward_gelu_quick_f16(
2456 const ggml_compute_params * params,
2457 ggml_tensor * dst) {
2458
2459 const ggml_tensor * src0 = dst->src[0];
2460
2461 assert(ggml_is_contiguous_rows(src0));
2462 assert(ggml_are_same_shape(src0, dst));
2463
2464 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2465 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2466 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2467 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2468
2469 const int ith = params->ith;
2470 const int nth = params->nth;
2471
2472 const int nc = src0->ne[0];
2473 const int nr = ggml_nrows(src0);
2474
2475 // rows per thread
2476 const int dr = (nr + nth - 1)/nth;
2477
2478 // row range for this thread
2479 const int ir0 = dr*ith;
2480 const int ir1 = MIN(ir0 + dr, nr);
2481
2482 for (int ir = ir0; ir < ir1; ++ir) {
2483 const int i3 = ir/(ne02*ne01);
2484 const int i2 = (ir - i3*ne02*ne01)/ne01;
2485 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2486
2487 ggml_vec_gelu_quick_f16(nc,
2488 (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2489 (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2490
2491#ifndef NDEBUG
2492 for (int k = 0; k < nc; k++) {
2493 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2494 const float v = GGML_CPU_FP16_TO_FP32(x);
2495 GGML_UNUSED(v);
2496 assert(!isnan(v));
2497 assert(!isinf(v));
2498 }
2499#endif
2500 }
2501}
2502
2503static void ggml_compute_forward_gelu_quick(
2504 const ggml_compute_params * params,
2505 ggml_tensor * dst) {
2506
2507 const ggml_tensor * src0 = dst->src[0];
2508
2509 switch (src0->type) {
2510 case GGML_TYPE_F32:
2511 {
2512 ggml_compute_forward_gelu_quick_f32(params, dst);
2513 } break;
2514 case GGML_TYPE_F16:
2515 {
2516 ggml_compute_forward_gelu_quick_f16(params, dst);
2517 } break;
2518 default:
2519 {
2520 GGML_ABORT("fatal error");
2521 }
2522 }
2523}
2524
2525// ggml_compute_forward_silu
2526
2527static void ggml_compute_forward_silu_f32(
2528 const ggml_compute_params * params,
2529 ggml_tensor * dst) {
2530
2531 const ggml_tensor * src0 = dst->src[0];
2532
2533 assert(ggml_is_contiguous_rows(src0));
2534 assert(ggml_are_same_shape(src0, dst));
2535
2536 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2537 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2538 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2539 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2540
2541 const int ith = params->ith;
2542 const int nth = params->nth;
2543
2544 const int nc = src0->ne[0];
2545 const int nr = ggml_nrows(src0);
2546
2547 // rows per thread
2548 const int dr = (nr + nth - 1)/nth;
2549
2550 // row range for this thread
2551 const int ir0 = dr*ith;
2552 const int ir1 = MIN(ir0 + dr, nr);
2553
2554 for (int ir = ir0; ir < ir1; ++ir) {
2555 const int i3 = ir/(ne02*ne01);
2556 const int i2 = (ir - i3*ne02*ne01)/ne01;
2557 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2558
2559 ggml_vec_silu_f32(nc,
2560 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2561 (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2562
2563#ifndef NDEBUG
2564 for (int k = 0; k < nc; k++) {
2565 const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2566 GGML_UNUSED(x);
2567 assert(!isnan(x));
2568 assert(!isinf(x));
2569 }
2570#endif
2571 }
2572}
2573
2574static void ggml_compute_forward_silu_f16(
2575 const ggml_compute_params * params,
2576 ggml_tensor * dst) {
2577
2578 const ggml_tensor * src0 = dst->src[0];
2579
2580 assert(ggml_is_contiguous_rows(src0));
2581 assert(ggml_are_same_shape(src0, dst));
2582
2583 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
2584 GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
2585 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
2586 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
2587
2588 const int ith = params->ith;
2589 const int nth = params->nth;
2590
2591 const int nc = src0->ne[0];
2592 const int nr = ggml_nrows(src0);
2593
2594 // rows per thread
2595 const int dr = (nr + nth - 1)/nth;
2596
2597 // row range for this thread
2598 const int ir0 = dr*ith;
2599 const int ir1 = MIN(ir0 + dr, nr);
2600
2601 for (int ir = ir0; ir < ir1; ++ir) {
2602 const int i3 = ir/(ne02*ne01);
2603 const int i2 = (ir - i3*ne02*ne01)/ne01;
2604 const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
2605
2606 ggml_vec_silu_f16(nc,
2607 (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
2608 (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
2609
2610#ifndef NDEBUG
2611 for (int k = 0; k < nc; k++) {
2612 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
2613 const float v = GGML_CPU_FP16_TO_FP32(x);
2614 GGML_UNUSED(v);
2615 assert(!isnan(v));
2616 assert(!isinf(v));
2617 }
2618#endif
2619 }
2620}
2621
2622static void ggml_compute_forward_silu(
2623 const ggml_compute_params * params,
2624 ggml_tensor * dst) {
2625
2626 const ggml_tensor * src0 = dst->src[0];
2627
2628 switch (src0->type) {
2629 case GGML_TYPE_F32:
2630 {
2631 ggml_compute_forward_silu_f32(params, dst);
2632 } break;
2633 case GGML_TYPE_F16:
2634 {
2635 ggml_compute_forward_silu_f16(params, dst);
2636 } break;
2637 default:
2638 {
2639 GGML_ABORT("fatal error");
2640 }
2641 }
2642}

// ggml_compute_forward_leaky_relu
2644
2645static void ggml_compute_forward_leaky_relu_f32(
2646 const ggml_compute_params * params,
2647 ggml_tensor * dst) {
2648
2649 const ggml_tensor * src0 = dst->src[0];
2650
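    // not parallelized: only the first thread does the work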
2651 if (params->ith != 0) {
2652 return;
2653 }
2654
2655 assert(ggml_is_contiguous_1(src0));
2656 assert(ggml_is_contiguous_1(dst));
2657 assert(ggml_are_same_shape(src0, dst));
2658
2659 const int n = ggml_nrows(src0);
2660 const int nc = src0->ne[0];
2661
2662 float negative_slope;
2663 memcpy(&negative_slope, dst->op_params, sizeof(float));
2664
2665 assert(dst->nb[0] == sizeof(float));
2666 assert(src0->nb[0] == sizeof(float));
2667
2668 for (int i = 0; i < n; i++) {
2669 ggml_vec_leaky_relu_f32(nc,
2670 (float *) ((char *) dst->data + i*( dst->nb[1])),
2671 (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2672 }
2673}
2674
2675static void ggml_compute_forward_leaky_relu_f16(
2676 const ggml_compute_params * params,
2677 ggml_tensor * dst) {
2678
2679 const ggml_tensor * src0 = dst->src[0];
2680
2681 if (params->ith != 0) {
2682 return;
2683 }
2684
2685 assert(ggml_is_contiguous_1(src0));
2686 assert(ggml_is_contiguous_1(dst));
2687 assert(ggml_are_same_shape(src0, dst));
2688
2689 const int n = ggml_nrows(src0);
2690 const int nc = src0->ne[0];
2691
2692 float negative_slope;
2693 memcpy(&negative_slope, dst->op_params, sizeof(float));
2694
2695 assert(dst->nb[0] == sizeof(ggml_fp16_t));
2696 assert(src0->nb[0] == sizeof(ggml_fp16_t));
2697
2698 for (int i = 0; i < n; i++) {
2699 ggml_vec_leaky_relu_f16(nc,
2700 (ggml_fp16_t *) ((char *) dst->data + i*( dst->nb[1])),
2701 (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
2702 }
2703}
2704
2705void ggml_compute_forward_leaky_relu(
2706 const ggml_compute_params * params,
2707 ggml_tensor * dst) {
2708
2709 const ggml_tensor * src0 = dst->src[0];
2710
2711 switch (src0->type) {
2712 case GGML_TYPE_F32:
2713 {
2714 ggml_compute_forward_leaky_relu_f32(params, dst);
2715 } break;
2716 case GGML_TYPE_F16:
2717 {
2718 ggml_compute_forward_leaky_relu_f16(params, dst);
2719 } break;
2720 default:
2721 {
2722 GGML_ABORT("fatal error");
2723 }
2724 }
2725}
2726
2727// ggml_compute_forward_silu_back
2728
2729static void ggml_compute_forward_silu_back_f32(
2730 const ggml_compute_params * params,
2731 ggml_tensor * dst) {
2732
2733 const ggml_tensor * grad = dst->src[0];
2734 const ggml_tensor * src1 = dst->src[1];
2735
2736 assert(ggml_is_contiguous_1(grad));
2737 assert(ggml_is_contiguous_1(src1));
2738 assert(ggml_is_contiguous_1(dst));
2739 assert(ggml_are_same_shape(src1, dst));
2740 assert(ggml_are_same_shape(src1, grad));
2741
2742 const int ith = params->ith;
2743 const int nth = params->nth;
2744
2745 const int nc = src1->ne[0];
2746 const int nr = ggml_nrows(src1);
2747
2748 // rows per thread
2749 const int dr = (nr + nth - 1)/nth;
2750
2751 // row range for this thread
2752 const int ir0 = dr*ith;
2753 const int ir1 = MIN(ir0 + dr, nr);
2754
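    // dst = grad * silu'(x), computed row by row, where x (src1) is the input of the forward pass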
2755 for (int i1 = ir0; i1 < ir1; i1++) {
2756 ggml_vec_silu_backward_f32(nc,
2757 (float *) ((char *) dst->data + i1*( dst->nb[1])),
2758 (float *) ((char *) src1->data + i1*(src1->nb[1])),
2759 (float *) ((char *) grad->data + i1*(grad->nb[1])));
2760
2761#ifndef NDEBUG
2762 for (int k = 0; k < nc; k++) {
2763 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2764 GGML_UNUSED(x);
2765 assert(!isnan(x));
2766 assert(!isinf(x));
2767 }
2768#endif
2769 }
2770}
2771
2772static void ggml_compute_forward_silu_back_f16(
2773 const ggml_compute_params * params,
2774 ggml_tensor * dst) {
2775
2776 const ggml_tensor * grad = dst->src[0];
2777 const ggml_tensor * src1 = dst->src[1];
2778
2779 assert(ggml_is_contiguous_1(grad));
2780 assert(ggml_is_contiguous_1(src1));
2781 assert(ggml_is_contiguous_1(dst));
2782 assert(ggml_are_same_shape(src1, dst));
2783 assert(ggml_are_same_shape(src1, grad));
2784
2785 const int ith = params->ith;
2786 const int nth = params->nth;
2787
2788 const int nc = src1->ne[0];
2789 const int nr = ggml_nrows(src1);
2790
2791 // rows per thread
2792 const int dr = (nr + nth - 1)/nth;
2793
2794 // row range for this thread
2795 const int ir0 = dr*ith;
2796 const int ir1 = MIN(ir0 + dr, nr);
2797
2798 for (int i1 = ir0; i1 < ir1; i1++) {
2799 ggml_vec_silu_backward_f16(nc,
2800 (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
2801 (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
2802 (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
2803
#ifndef NDEBUG
        for (int k = 0; k < nc; k++) {
            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
            const float v = GGML_CPU_FP16_TO_FP32(x);
            GGML_UNUSED(v);
            assert(!isnan(v));
            assert(!isinf(v));
        }
#endif
2813 }
2814}
2815
2816void ggml_compute_forward_silu_back(
2817 const ggml_compute_params * params,
2818 ggml_tensor * dst) {
2819
2820 const ggml_tensor * src0 = dst->src[0];
2821
2822 switch (src0->type) {
2823 case GGML_TYPE_F32:
2824 {
2825 ggml_compute_forward_silu_back_f32(params, dst);
2826 } break;
2827 case GGML_TYPE_F16:
2828 {
2829 ggml_compute_forward_silu_back_f16(params, dst);
2830 } break;
2831 default:
2832 {
2833 GGML_ABORT("fatal error");
2834 }
2835 }
2836}
2837
2838// ggml_compute_forward_reglu
2839
2840static void ggml_compute_forward_reglu_f32(
2841 const ggml_compute_params * params,
2842 ggml_tensor * dst) {
2843
2844 const ggml_tensor * src0 = dst->src[0];
2845 const ggml_tensor * src1 = dst->src[1];
2846 char * src0_d = (char *) src0->data;
2847 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2848 const size_t src0_o = src0->nb[1];
2849 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2850
2851 GGML_ASSERT(ggml_is_contiguous_1(src0));
2852 GGML_ASSERT(ggml_is_contiguous_1(dst));
2853
2854 if (src1) {
2855 GGML_ASSERT(ggml_is_contiguous_1(src1));
2856 GGML_ASSERT(src0->type == src1->type);
2857 }
2858
2859 const int ith = params->ith;
2860 const int nth = params->nth;
2861
2862 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2863 const int nr = ggml_nrows(src0);
2864
2865 GGML_ASSERT(dst->ne[0] == nc);
2866 GGML_ASSERT(ggml_nrows(dst) == nr);
2867
2868 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2869
2870 // rows per thread
2871 const int dr = (nr + nth - 1)/nth;
2872
2873 // row range for this thread
2874 const int ir0 = dr*ith;
2875 const int ir1 = MIN(ir0 + dr, nr);
2876
2877 for (int i1 = ir0; i1 < ir1; i1++) {
2878 float * src0_p = (float *) (src0_d + i1*src0_o);
2879 float * src1_p = (float *) (src1_d + i1*src1_o);
2880
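        // single-tensor GLU: the two halves of the src0 row (each of width nc) are the two
        // operands; `swapped` selects their order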
2881 if (!src1) {
2882 src0_p += swapped ? nc : 0;
2883 src1_p += swapped ? 0 : nc;
2884 }
2885
2886 ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2887
2888#ifndef NDEBUG
2889 for (int k = 0; k < nc; k++) {
2890 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2891 GGML_UNUSED(x);
2892 assert(!isnan(x));
2893 assert(!isinf(x));
2894 }
2895#endif
2896 }
2897}
2898
2899static void ggml_compute_forward_reglu_f16(
2900 const ggml_compute_params * params,
2901 ggml_tensor * dst) {
2902
2903 const ggml_tensor * src0 = dst->src[0];
2904 const ggml_tensor * src1 = dst->src[1];
2905 char * src0_d = (char *) src0->data;
2906 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2907 const size_t src0_o = src0->nb[1];
2908 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2909
2910 GGML_ASSERT(ggml_is_contiguous_1(src0));
2911 GGML_ASSERT(ggml_is_contiguous_1(dst));
2912
2913 if (src1) {
2914 GGML_ASSERT(ggml_is_contiguous_1(src1));
2915 GGML_ASSERT(src0->type == src1->type);
2916 }
2917
2918 const int ith = params->ith;
2919 const int nth = params->nth;
2920
2921 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
2922 const int nr = ggml_nrows(src0);
2923
2924 GGML_ASSERT(dst->ne[0] == nc);
2925 GGML_ASSERT(ggml_nrows(dst) == nr);
2926
2927 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
2928
2929 // rows per thread
2930 const int dr = (nr + nth - 1)/nth;
2931
2932 // row range for this thread
2933 const int ir0 = dr*ith;
2934 const int ir1 = MIN(ir0 + dr, nr);
2935
2936 for (int i1 = ir0; i1 < ir1; i1++) {
2937 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
2938 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
2939
2940 if (!src1) {
2941 src0_p += swapped ? nc : 0;
2942 src1_p += swapped ? 0 : nc;
2943 }
2944
2945 ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
2946
2947#ifndef NDEBUG
2948 for (int k = 0; k < nc; k++) {
2949 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
2950 const float v = GGML_FP16_TO_FP32(x);
2951 GGML_UNUSED(v);
2952 assert(!isnan(v));
2953 assert(!isinf(v));
2954 }
2955#endif
2956 }
2957}
2958
2959static void ggml_compute_forward_reglu(
2960 const ggml_compute_params * params,
2961 ggml_tensor * dst) {
2962
2963 const ggml_tensor * src0 = dst->src[0];
2964
2965 switch (src0->type) {
2966 case GGML_TYPE_F32:
2967 {
2968 ggml_compute_forward_reglu_f32(params, dst);
2969 } break;
2970 case GGML_TYPE_F16:
2971 {
2972 ggml_compute_forward_reglu_f16(params, dst);
2973 } break;
2974 default:
2975 {
2976 GGML_ABORT("fatal error");
2977 }
2978 }
2979}
2980
2981// ggml_compute_forward_geglu
2982
2983static void ggml_compute_forward_geglu_f32(
2984 const ggml_compute_params * params,
2985 ggml_tensor * dst) {
2986
2987 const ggml_tensor * src0 = dst->src[0];
2988 const ggml_tensor * src1 = dst->src[1];
2989 char * src0_d = (char *) src0->data;
2990 char * src1_d = (char *) (src1 ? src1->data : src0->data);
2991 const size_t src0_o = src0->nb[1];
2992 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
2993
2994 GGML_ASSERT(ggml_is_contiguous_1(src0));
2995 GGML_ASSERT(ggml_is_contiguous_1(dst));
2996
2997 if (src1) {
2998 GGML_ASSERT(ggml_is_contiguous_1(src1));
2999 GGML_ASSERT(src0->type == src1->type);
3000 }
3001
3002 const int ith = params->ith;
3003 const int nth = params->nth;
3004
3005 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3006 const int nr = ggml_nrows(src0);
3007
3008 GGML_ASSERT(dst->ne[0] == nc);
3009 GGML_ASSERT(ggml_nrows(dst) == nr);
3010
3011 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3012
3013 // rows per thread
3014 const int dr = (nr + nth - 1)/nth;
3015
3016 // row range for this thread
3017 const int ir0 = dr*ith;
3018 const int ir1 = MIN(ir0 + dr, nr);
3019
3020 for (int i1 = ir0; i1 < ir1; i1++) {
3021 float * src0_p = (float *) (src0_d + i1*src0_o);
3022 float * src1_p = (float *) (src1_d + i1*src1_o);
3023
3024 if (!src1) {
3025 src0_p += swapped ? nc : 0;
3026 src1_p += swapped ? 0 : nc;
3027 }
3028
3029 ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3030
3031#ifndef NDEBUG
3032 for (int k = 0; k < nc; k++) {
3033 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3034 GGML_UNUSED(x);
3035 assert(!isnan(x));
3036 assert(!isinf(x));
3037 }
3038#endif
3039 }
3040}
3041
3042static void ggml_compute_forward_geglu_f16(
3043 const ggml_compute_params * params,
3044 ggml_tensor * dst) {
3045
3046 const ggml_tensor * src0 = dst->src[0];
3047 const ggml_tensor * src1 = dst->src[1];
3048 char * src0_d = (char *) src0->data;
3049 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3050 const size_t src0_o = src0->nb[1];
3051 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3052
3053 GGML_ASSERT(ggml_is_contiguous_1(src0));
3054 GGML_ASSERT(ggml_is_contiguous_1(dst));
3055
3056 if (src1) {
3057 GGML_ASSERT(ggml_is_contiguous_1(src1));
3058 GGML_ASSERT(src0->type == src1->type);
3059 }
3060
3061 const int ith = params->ith;
3062 const int nth = params->nth;
3063
3064 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3065 const int nr = ggml_nrows(src0);
3066
3067 GGML_ASSERT(dst->ne[0] == nc);
3068 GGML_ASSERT(ggml_nrows(dst) == nr);
3069
3070 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3071
3072 // rows per thread
3073 const int dr = (nr + nth - 1)/nth;
3074
3075 // row range for this thread
3076 const int ir0 = dr*ith;
3077 const int ir1 = MIN(ir0 + dr, nr);
3078
3079 for (int i1 = ir0; i1 < ir1; i1++) {
3080 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3081 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3082
3083 if (!src1) {
3084 src0_p += swapped ? nc : 0;
3085 src1_p += swapped ? 0 : nc;
3086 }
3087
3088 ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3089
3090#ifndef NDEBUG
3091 for (int k = 0; k < nc; k++) {
3092 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3093 const float v = GGML_FP16_TO_FP32(x);
3094 GGML_UNUSED(v);
3095 assert(!isnan(v));
3096 assert(!isinf(v));
3097 }
3098#endif
3099 }
3100}
3101
3102static void ggml_compute_forward_geglu(
3103 const ggml_compute_params * params,
3104 ggml_tensor * dst) {
3105
3106 const ggml_tensor * src0 = dst->src[0];
3107
3108 switch (src0->type) {
3109 case GGML_TYPE_F32:
3110 {
3111 ggml_compute_forward_geglu_f32(params, dst);
3112 } break;
3113 case GGML_TYPE_F16:
3114 {
3115 ggml_compute_forward_geglu_f16(params, dst);
3116 } break;
3117 default:
3118 {
3119 GGML_ABORT("fatal error");
3120 }
3121 }
3122}
3123
3124// ggml_compute_forward_swiglu
3125
3126static void ggml_compute_forward_swiglu_f32(
3127 const ggml_compute_params * params,
3128 ggml_tensor * dst) {
3129
3130 const ggml_tensor * src0 = dst->src[0];
3131 const ggml_tensor * src1 = dst->src[1];
3132 char * src0_d = (char *) src0->data;
3133 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3134 const size_t src0_o = src0->nb[1];
3135 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3136
3137 GGML_ASSERT(ggml_is_contiguous_1(src0));
3138 GGML_ASSERT(ggml_is_contiguous_1(dst));
3139
3140 if (src1) {
3141 GGML_ASSERT(ggml_is_contiguous_1(src1));
3142 GGML_ASSERT(src0->type == src1->type);
3143 }
3144
3145 const int ith = params->ith;
3146 const int nth = params->nth;
3147
3148 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3149 const int nr = ggml_nrows(src0);
3150
3151 GGML_ASSERT(dst->ne[0] == nc);
3152 GGML_ASSERT(ggml_nrows(dst) == nr);
3153
3154 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3155
3156 // rows per thread
3157 const int dr = (nr + nth - 1)/nth;
3158
3159 // row range for this thread
3160 const int ir0 = dr*ith;
3161 const int ir1 = MIN(ir0 + dr, nr);
3162
3163 for (int i1 = ir0; i1 < ir1; i1++) {
3164 float * src0_p = (float *) (src0_d + i1*src0_o);
3165 float * src1_p = (float *) (src1_d + i1*src1_o);
3166
3167 if (!src1) {
3168 src0_p += swapped ? nc : 0;
3169 src1_p += swapped ? 0 : nc;
3170 }
3171
3172 ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3173
3174#ifndef NDEBUG
3175 for (int k = 0; k < nc; k++) {
3176 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3177 GGML_UNUSED(x);
3178 assert(!isnan(x));
3179 assert(!isinf(x));
3180 }
3181#endif
3182 }
3183}
3184
3185static void ggml_compute_forward_swiglu_f16(
3186 const ggml_compute_params * params,
3187 ggml_tensor * dst) {
3188
3189 const ggml_tensor * src0 = dst->src[0];
3190 const ggml_tensor * src1 = dst->src[1];
3191 char * src0_d = (char *) src0->data;
3192 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3193 const size_t src0_o = src0->nb[1];
3194 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3195
3196 GGML_ASSERT(ggml_is_contiguous_1(src0));
3197 GGML_ASSERT(ggml_is_contiguous_1(dst));
3198
3199 if (src1) {
3200 GGML_ASSERT(ggml_is_contiguous_1(src1));
3201 GGML_ASSERT(src0->type == src1->type);
3202 }
3203
3204 const int ith = params->ith;
3205 const int nth = params->nth;
3206
3207 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3208 const int nr = ggml_nrows(src0);
3209
3210 GGML_ASSERT(dst->ne[0] == nc);
3211 GGML_ASSERT(ggml_nrows(dst) == nr);
3212
3213 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3214
3215 // rows per thread
3216 const int dr = (nr + nth - 1)/nth;
3217
3218 // row range for this thread
3219 const int ir0 = dr*ith;
3220 const int ir1 = MIN(ir0 + dr, nr);
3221
3222 for (int i1 = ir0; i1 < ir1; i1++) {
3223 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3224 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3225
3226 if (!src1) {
3227 src0_p += swapped ? nc : 0;
3228 src1_p += swapped ? 0 : nc;
3229 }
3230
3231 ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3232
3233#ifndef NDEBUG
3234 for (int k = 0; k < nc; k++) {
3235 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3236 const float v = GGML_FP16_TO_FP32(x);
3237 GGML_UNUSED(v);
3238 assert(!isnan(v));
3239 assert(!isinf(v));
3240 }
3241#endif
3242 }
3243}
3244
3245static void ggml_compute_forward_swiglu(
3246 const ggml_compute_params * params,
3247 ggml_tensor * dst) {
3248
3249 const ggml_tensor * src0 = dst->src[0];
3250
3251 switch (src0->type) {
3252 case GGML_TYPE_F32:
3253 {
3254 ggml_compute_forward_swiglu_f32(params, dst);
3255 } break;
3256 case GGML_TYPE_F16:
3257 {
3258 ggml_compute_forward_swiglu_f16(params, dst);
3259 } break;
3260 default:
3261 {
3262 GGML_ABORT("fatal error");
3263 }
3264 }
3265}
3266
3267// ggml_compute_forward_swiglu_oai
3268
3269static void ggml_compute_forward_swiglu_oai_f32(
3270 const ggml_compute_params * params,
3271 ggml_tensor * dst) {
3272
3273 const ggml_tensor * src0 = dst->src[0];
3274 const ggml_tensor * src1 = dst->src[1];
3275 char * src0_d = (char *) src0->data;
3276 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3277 const size_t src0_o = src0->nb[1];
3278 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3279
3280 GGML_ASSERT(ggml_is_contiguous_1(src0));
3281 GGML_ASSERT(ggml_is_contiguous_1(dst));
3282
3283 if (src1) {
3284 GGML_ASSERT(ggml_is_contiguous_1(src1));
3285 GGML_ASSERT(src0->type == src1->type);
3286 }
3287
3288 const int ith = params->ith;
3289 const int nth = params->nth;
3290
3291 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3292 const int nr = ggml_nrows(src0);
3293
3294 GGML_ASSERT(dst->ne[0] == nc);
3295 GGML_ASSERT(ggml_nrows(dst) == nr);
3296
3297 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3298 const float alpha = ggml_get_op_params_f32(dst, 2);
3299 const float limit = ggml_get_op_params_f32(dst, 3);
3300
3301 // rows per thread
3302 const int dr = (nr + nth - 1)/nth;
3303
3304 // row range for this thread
3305 const int ir0 = dr*ith;
3306 const int ir1 = MIN(ir0 + dr, nr);
3307
3308 for (int i1 = ir0; i1 < ir1; i1++) {
3309 float * src0_p = (float *) (src0_d + i1*src0_o);
3310 float * src1_p = (float *) (src1_d + i1*src1_o);
3311 float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
3312
3313 if (!src1) {
3314 src0_p += swapped ? nc : 0;
3315 src1_p += swapped ? 0 : nc;
3316 }
3317
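        // cap x at `limit`, clamp y to [-limit, limit], apply a swish with slope alpha to x
        // and scale by (y + 1): dst = (x / (1 + exp(-alpha*x))) * (y + 1)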
3318 for (int k = 0; k < nc; k++) {
3319 const float x = std::min(src0_p[k], limit);
3320 const float y = std::clamp(src1_p[k], -limit, limit);
3321 const float out_glu = x / (1.f + expf(alpha * (-x)));
3322 dst_p[k] = out_glu * (y + 1.f);
3323 }
3324
3325#ifndef NDEBUG
3326 for (int k = 0; k < nc; k++) {
3327 const float x = dst_p[k];
3328 GGML_UNUSED(x);
3329 assert(!isnan(x));
3330 assert(!isinf(x));
3331 }
3332#endif
3333 }
3334}
3335
3336static void ggml_compute_forward_swiglu_oai(
3337 const ggml_compute_params * params,
3338 ggml_tensor * dst) {
3339
3340 const ggml_tensor * src0 = dst->src[0];
3341
3342 switch (src0->type) {
3343 case GGML_TYPE_F32:
3344 {
3345 ggml_compute_forward_swiglu_oai_f32(params, dst);
3346 } break;
3347 default:
3348 {
3349 GGML_ABORT("fatal error");
3350 }
3351 }
3352}
3353
3354// ggml_compute_forward_geglu_erf
3355
3356static void ggml_compute_forward_geglu_erf_f32(
3357 const ggml_compute_params * params,
3358 ggml_tensor * dst) {
3359
3360 const ggml_tensor * src0 = dst->src[0];
3361 const ggml_tensor * src1 = dst->src[1];
3362 char * src0_d = (char *) src0->data;
3363 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3364 const size_t src0_o = src0->nb[1];
3365 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3366
3367 GGML_ASSERT(ggml_is_contiguous_1(src0));
3368 GGML_ASSERT(ggml_is_contiguous_1(dst));
3369
3370 if (src1) {
3371 GGML_ASSERT(ggml_is_contiguous_1(src1));
3372 GGML_ASSERT(src0->type == src1->type);
3373 }
3374
3375 const int ith = params->ith;
3376 const int nth = params->nth;
3377
3378 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3379 const int nr = ggml_nrows(src0);
3380
3381 GGML_ASSERT(dst->ne[0] == nc);
3382 GGML_ASSERT(ggml_nrows(dst) == nr);
3383
3384 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3385
3386 // rows per thread
3387 const int dr = (nr + nth - 1)/nth;
3388
3389 // row range for this thread
3390 const int ir0 = dr*ith;
3391 const int ir1 = MIN(ir0 + dr, nr);
3392
3393 for (int i1 = ir0; i1 < ir1; i1++) {
3394 float * src0_p = (float *) (src0_d + i1*src0_o);
3395 float * src1_p = (float *) (src1_d + i1*src1_o);
3396
3397 if (!src1) {
3398 src0_p += swapped ? nc : 0;
3399 src1_p += swapped ? 0 : nc;
3400 }
3401
3402 ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3403
3404#ifndef NDEBUG
3405 for (int k = 0; k < nc; k++) {
3406 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3407 GGML_UNUSED(x);
3408 assert(!isnan(x));
3409 assert(!isinf(x));
3410 }
3411#endif
3412 }
3413}
3414
3415static void ggml_compute_forward_geglu_erf_f16(
3416 const ggml_compute_params * params,
3417 ggml_tensor * dst) {
3418
3419 const ggml_tensor * src0 = dst->src[0];
3420 const ggml_tensor * src1 = dst->src[1];
3421 char * src0_d = (char *) src0->data;
3422 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3423 const size_t src0_o = src0->nb[1];
3424 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3425
3426 GGML_ASSERT(ggml_is_contiguous_1(src0));
3427 GGML_ASSERT(ggml_is_contiguous_1(dst));
3428
3429 if (src1) {
3430 GGML_ASSERT(ggml_is_contiguous_1(src1));
3431 GGML_ASSERT(src0->type == src1->type);
3432 }
3433
3434 const int ith = params->ith;
3435 const int nth = params->nth;
3436
3437 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3438 const int nr = ggml_nrows(src0);
3439
3440 GGML_ASSERT(dst->ne[0] == nc);
3441 GGML_ASSERT(ggml_nrows(dst) == nr);
3442
3443 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3444
3445 // rows per thread
3446 const int dr = (nr + nth - 1)/nth;
3447
3448 // row range for this thread
3449 const int ir0 = dr*ith;
3450 const int ir1 = MIN(ir0 + dr, nr);
3451
3452 for (int i1 = ir0; i1 < ir1; i1++) {
3453 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3454 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3455
3456 if (!src1) {
3457 src0_p += swapped ? nc : 0;
3458 src1_p += swapped ? 0 : nc;
3459 }
3460
3461 ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3462
3463#ifndef NDEBUG
3464 for (int k = 0; k < nc; k++) {
3465 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3466 const float v = GGML_FP16_TO_FP32(x);
3467 GGML_UNUSED(v);
3468 assert(!isnan(v));
3469 assert(!isinf(v));
3470 }
3471#endif
3472 }
3473}
3474
3475static void ggml_compute_forward_geglu_erf(
3476 const ggml_compute_params * params,
3477 ggml_tensor * dst) {
3478
3479 const ggml_tensor * src0 = dst->src[0];
3480
3481 switch (src0->type) {
3482 case GGML_TYPE_F32:
3483 {
3484 ggml_compute_forward_geglu_erf_f32(params, dst);
3485 } break;
3486 case GGML_TYPE_F16:
3487 {
3488 ggml_compute_forward_geglu_erf_f16(params, dst);
3489 } break;
3490 default:
3491 {
3492 GGML_ABORT("fatal error");
3493 }
3494 }
3495}
3496
3497// ggml_compute_forward_geglu_quick
3498
3499static void ggml_compute_forward_geglu_quick_f32(
3500 const ggml_compute_params * params,
3501 ggml_tensor * dst) {
3502
3503 const ggml_tensor * src0 = dst->src[0];
3504 const ggml_tensor * src1 = dst->src[1];
3505 char * src0_d = (char *) src0->data;
3506 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3507 const size_t src0_o = src0->nb[1];
3508 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3509
3510 GGML_ASSERT(ggml_is_contiguous_1(src0));
3511 GGML_ASSERT(ggml_is_contiguous_1(dst));
3512
3513 if (src1) {
3514 GGML_ASSERT(ggml_is_contiguous_1(src1));
3515 GGML_ASSERT(src0->type == src1->type);
3516 }
3517
3518 const int ith = params->ith;
3519 const int nth = params->nth;
3520
3521 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3522 const int nr = ggml_nrows(src0);
3523
3524 GGML_ASSERT(dst->ne[0] == nc);
3525 GGML_ASSERT(ggml_nrows(dst) == nr);
3526
3527 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3528
3529 // rows per thread
3530 const int dr = (nr + nth - 1)/nth;
3531
3532 // row range for this thread
3533 const int ir0 = dr*ith;
3534 const int ir1 = MIN(ir0 + dr, nr);
3535
3536 for (int i1 = ir0; i1 < ir1; i1++) {
3537 float * src0_p = (float *) (src0_d + i1*src0_o);
3538 float * src1_p = (float *) (src1_d + i1*src1_o);
3539
3540 if (!src1) {
3541 src0_p += swapped ? nc : 0;
3542 src1_p += swapped ? 0 : nc;
3543 }
3544
3545 ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3546
3547#ifndef NDEBUG
3548 for (int k = 0; k < nc; k++) {
3549 const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3550 GGML_UNUSED(x);
3551 assert(!isnan(x));
3552 assert(!isinf(x));
3553 }
3554#endif
3555 }
3556}
3557
3558static void ggml_compute_forward_geglu_quick_f16(
3559 const ggml_compute_params * params,
3560 ggml_tensor * dst) {
3561
3562 const ggml_tensor * src0 = dst->src[0];
3563 const ggml_tensor * src1 = dst->src[1];
3564 char * src0_d = (char *) src0->data;
3565 char * src1_d = (char *) (src1 ? src1->data : src0->data);
3566 const size_t src0_o = src0->nb[1];
3567 const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
3568
3569 GGML_ASSERT(ggml_is_contiguous_1(src0));
3570 GGML_ASSERT(ggml_is_contiguous_1(dst));
3571
3572 if (src1) {
3573 GGML_ASSERT(ggml_is_contiguous_1(src1));
3574 GGML_ASSERT(src0->type == src1->type);
3575 }
3576
3577 const int ith = params->ith;
3578 const int nth = params->nth;
3579
3580 const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
3581 const int nr = ggml_nrows(src0);
3582
3583 GGML_ASSERT(dst->ne[0] == nc);
3584 GGML_ASSERT(ggml_nrows(dst) == nr);
3585
3586 const int32_t swapped = ggml_get_op_params_i32(dst, 1);
3587
3588 // rows per thread
3589 const int dr = (nr + nth - 1)/nth;
3590
3591 // row range for this thread
3592 const int ir0 = dr*ith;
3593 const int ir1 = MIN(ir0 + dr, nr);
3594
3595 for (int i1 = ir0; i1 < ir1; i1++) {
3596 ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
3597 ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
3598
3599 if (!src1) {
3600 src0_p += swapped ? nc : 0;
3601 src1_p += swapped ? 0 : nc;
3602 }
3603
3604 ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
3605
3606#ifndef NDEBUG
3607 for (int k = 0; k < nc; k++) {
3608 const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
3609 const float v = GGML_FP16_TO_FP32(x);
3610 GGML_UNUSED(v);
3611 assert(!isnan(v));
3612 assert(!isinf(v));
3613 }
3614#endif
3615 }
3616}
3617
3618static void ggml_compute_forward_geglu_quick(
3619 const ggml_compute_params * params,
3620 ggml_tensor * dst) {
3621
3622 const ggml_tensor * src0 = dst->src[0];
3623
3624 switch (src0->type) {
3625 case GGML_TYPE_F32:
3626 {
3627 ggml_compute_forward_geglu_quick_f32(params, dst);
3628 } break;
3629 case GGML_TYPE_F16:
3630 {
3631 ggml_compute_forward_geglu_quick_f16(params, dst);
3632 } break;
3633 default:
3634 {
3635 GGML_ABORT("fatal error");
3636 }
3637 }
3638}
3639
3640// ggml_compute_forward_norm
3641
3642static void ggml_compute_forward_norm_f32(
3643 const ggml_compute_params * params,
3644 ggml_tensor * dst) {
3645
3646 const ggml_tensor * src0 = dst->src[0];
3647
3648 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3649
3650 GGML_ASSERT(src0->nb[0] == sizeof(float));
3651
3652 const int ith = params->ith;
3653 const int nth = params->nth;
3654
3655 GGML_TENSOR_UNARY_OP_LOCALS
3656
3657 float eps;
3658 memcpy(&eps, dst->op_params, sizeof(float));
3659
3660 GGML_ASSERT(eps >= 0.0f);
3661
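    // layer norm over the first dimension: y = (x - mean) / sqrt(variance + eps)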
3662 for (int64_t i03 = 0; i03 < ne03; i03++) {
3663 for (int64_t i02 = 0; i02 < ne02; i02++) {
3664 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3665 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3666
3667 float sum = 0.0;
3668 ggml_vec_sum_f32(ne00, &sum, x);
3669 float mean = sum/ne00;
3670
3671 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3672 float variance = 0;
3673
3674#ifdef GGML_USE_ACCELERATE
3675 mean = -mean;
3676 vDSP_vsadd(x, 1, &mean, y, 1, ne00);
3677 vDSP_measqv(y, 1, &variance, ne00);
3678#else
3679 variance = ggml_vec_cvar_f32(ne00, y, x, mean);
3680#endif //GGML_USE_ACCELERATE
3681
3682 const float scale = 1.0f/sqrtf(variance + eps);
3683 ggml_vec_scale_f32(ne00, y, scale);
3684 }
3685 }
3686 }
3687}
3688
3689void ggml_compute_forward_norm(
3690 const ggml_compute_params * params,
3691 ggml_tensor * dst) {
3692
3693 const ggml_tensor * src0 = dst->src[0];
3694
3695 switch (src0->type) {
3696 case GGML_TYPE_F32:
3697 {
3698 ggml_compute_forward_norm_f32(params, dst);
3699 } break;
3700 default:
3701 {
3702 GGML_ABORT("fatal error");
3703 }
3704 }
3705}
3706
// ggml_compute_forward_rms_norm
3708
3709static void ggml_compute_forward_rms_norm_f32(
3710 const ggml_compute_params * params,
3711 ggml_tensor * dst) {
3712
3713 const ggml_tensor * src0 = dst->src[0];
3714
3715 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3716
3717 GGML_ASSERT(src0->nb[0] == sizeof(float));
3718
3719 const int ith = params->ith;
3720 const int nth = params->nth;
3721
3722 GGML_TENSOR_UNARY_OP_LOCALS
3723
3724 float eps;
3725 memcpy(&eps, dst->op_params, sizeof(float));
3726
3727 GGML_ASSERT(eps >= 0.0f);
3728
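    // RMS norm over the first dimension: y = x / sqrt(mean(x^2) + eps)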
3729 // TODO: optimize
3730 for (int64_t i03 = 0; i03 < ne03; i03++) {
3731 for (int64_t i02 = 0; i02 < ne02; i02++) {
3732 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3733 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3734
3735 ggml_float sum = 0.0;
3736 for (int64_t i00 = 0; i00 < ne00; i00++) {
3737 sum += (ggml_float)(x[i00] * x[i00]);
3738 }
3739
3740 const float mean = sum/ne00;
3741
3742 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3743
3744 memcpy(y, x, ne00 * sizeof(float));
3745 // for (int i00 = 0; i00 < ne00; i00++) {
3746 // y[i00] = x[i00];
3747 // }
3748
3749 const float scale = 1.0f/sqrtf(mean + eps);
3750
3751 // if you hit this, likely you got an inf somewhere earlier
3752 assert(scale > 0.0f);
3753
3754 ggml_vec_scale_f32(ne00, y, scale);
3755 }
3756 }
3757 }
3758}
3759
3760void ggml_compute_forward_rms_norm(
3761 const ggml_compute_params * params,
3762 ggml_tensor * dst) {
3763
3764 const ggml_tensor * src0 = dst->src[0];
3765
3766 switch (src0->type) {
3767 case GGML_TYPE_F32:
3768 {
3769 ggml_compute_forward_rms_norm_f32(params, dst);
3770 } break;
3771 default:
3772 {
3773 GGML_ABORT("fatal error");
3774 }
3775 }
3776}
3777
3778static void ggml_compute_forward_rms_norm_back_f32(
3779 const ggml_compute_params * params,
3780 ggml_tensor * dst) {
3781
    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
    const ggml_tensor * src1 = dst->src[1]; // input (x) of the forward pass
3784
3785 GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
3786
3787 GGML_ASSERT(src0->nb[0] == sizeof(float));
3788 GGML_ASSERT(src1->nb[0] == sizeof(float));
3789
3790 const int ith = params->ith;
3791 const int nth = params->nth;
3792
3793 GGML_TENSOR_BINARY_OP_LOCALS
3794
3795 float eps;
3796 memcpy(&eps, dst->op_params, sizeof(float));
3797
3798 // TODO: optimize
3799 for (int64_t i03 = 0; i03 < ne03; i03++) {
3800 for (int64_t i02 = 0; i02 < ne02; i02++) {
3801 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
3802 // src1 is same shape as src0 => same indices
3803 const int64_t i11 = i01;
3804 const int64_t i12 = i02;
3805 const int64_t i13 = i03;
3806
3807 const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
3808 const float * x = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
3809
3810 ggml_float sum_xx = 0.0;
3811 ggml_float sum_xdz = 0.0;
3812
3813 for (int64_t i00 = 0; i00 < ne00; i00++) {
3814 sum_xx += (ggml_float)(x[i00] * x[i00]);
3815 sum_xdz += (ggml_float)(x[i00] * dz[i00]);
3816 }
3817
3818 //const float mean = (float)(sum_xx)/ne00;
3819 const float mean_eps = (float)(sum_xx)/ne00 + eps;
3820 const float sum_eps = (float)(sum_xx) + eps*ne00;
3821 //const float mean_xdz = (float)(sum_xdz)/ne00;
3822 // we could cache rms from forward pass to improve performance.
3823 // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
3824 //const float rms = sqrtf(mean_eps);
3825 const float rrms = 1.0f / sqrtf(mean_eps);
3826 //const float scale = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
3827
3828 {
3829 // z = rms_norm(x)
3830 //
3831 // rms_norm(src1) =
3832 // scale(
3833 // src1,
3834 // div(
3835 // 1,
3836 // sqrt(
3837 // add(
3838 // scale(
3839 // sum(
3840 // sqr(
3841 // src1)),
3842 // (1.0/N)),
3843 // eps))));
3844
3845 // postorder:
3846 // ## op args grad
3847 // 00 param src1 grad[#00]
3848 // 01 const 1
3849 // 02 sqr (#00) grad[#02]
3850 // 03 sum (#02) grad[#03]
3851 // 04 const 1/N
3852 // 05 scale (#03, #04) grad[#05]
3853 // 06 const eps
3854 // 07 add (#05, #06) grad[#07]
3855 // 08 sqrt (#07) grad[#08]
3856 // 09 div (#01,#08) grad[#09]
3857 // 10 scale (#00,#09) grad[#10]
3858 //
3859 // backward pass, given grad[#10]
3860 // #10: scale
3861 // grad[#00] += scale(grad[#10],#09)
3862 // grad[#09] += sum(mul(grad[#10],#00))
3863 // #09: div
3864 // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
3865 // #08: sqrt
3866 // grad[#07] += mul(grad[#08], div(0.5, #08))
3867 // #07: add
3868 // grad[#05] += grad[#07]
3869 // #05: scale
3870 // grad[#03] += scale(grad[#05],#04)
3871 // #03: sum
3872 // grad[#02] += repeat(grad[#03], #02)
3873 // #02:
3874 // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
3875 //
3876 // substitute and simplify:
3877 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3878 // grad[#02] = repeat(grad[#03], #02)
3879 // grad[#02] = repeat(scale(grad[#05],#04), #02)
3880 // grad[#02] = repeat(scale(grad[#07],#04), #02)
3881 // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
3882 // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
3883 // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
3884 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
3885 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
3886 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
3887 // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
3888 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
3889 // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
3890 // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
3891 // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
3892 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3893 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
3894 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
3895 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
3896 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
3897 // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
3898 // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
3899 // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
3900 // a = b*c + d*e
3901 // a = b*c*f/f + d*e*f/f
3902 // a = (b*c*f + d*e*f)*(1/f)
3903 // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
3904 // a = (b + d*e/c)*c
3905 // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
3906 // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
3907 // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
3908 // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
3909 // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
3910 // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
3911 // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
3912 // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
3913 // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3914 // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3915 }
3916 // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
3917 // post-order:
3918 // dx := x
3919 // dx := scale(dx,-mean_xdz/mean_eps)
3920 // dx := add(dx, dz)
3921 // dx := scale(dx, rrms)
3922 float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
3923
3924 // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
3925 ggml_vec_cpy_f32 (ne00, dx, x);
3926 // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
3927 ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
3928 ggml_vec_acc_f32 (ne00, dx, dz);
3929 ggml_vec_scale_f32(ne00, dx, rrms);
3930 }
3931 }
3932 }
3933}
3934
3935void ggml_compute_forward_rms_norm_back(
3936 const ggml_compute_params * params,
3937 ggml_tensor * dst) {
3938
3939 const ggml_tensor * src0 = dst->src[0];
3940
3941 switch (src0->type) {
3942 case GGML_TYPE_F32:
3943 {
3944 ggml_compute_forward_rms_norm_back_f32(params, dst);
3945 } break;
3946 default:
3947 {
3948 GGML_ABORT("fatal error");
3949 }
3950 }
3951}
3952
3953// ggml_compute_forward_group_norm
3954
3955static void ggml_compute_forward_group_norm_f32(
3956 const ggml_compute_params * params,
3957 ggml_tensor * dst) {
3958
3959 const ggml_tensor * src0 = dst->src[0];
3960
3961 GGML_ASSERT(ggml_are_same_shape(src0, dst));
3962
3963 GGML_ASSERT(src0->nb[0] == sizeof(float));
3964
3965 const int ith = params->ith;
3966 const int nth = params->nth;
3967
3968 GGML_TENSOR_UNARY_OP_LOCALS
3969
3970 // TODO: optimize
3971
3972 float eps;
3973 memcpy(&eps, dst->op_params + 1, sizeof(float));
3974
3975 int n_channels = src0->ne[2];
3976 int n_groups = dst->op_params[0];
3977 int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
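    // normalize each group of channels to zero mean and unit variance over its ne00*ne01*step elements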
3978 for (int i = ith; i < n_groups; i += nth) {
3979 int start = i * n_channels_per_group;
3980 int end = start + n_channels_per_group;
3981 if (end > n_channels) {
3982 end = n_channels;
3983 }
3984 int step = end - start;
3985
3986 for (int64_t i03 = 0; i03 < ne03; i03++) {
3987 ggml_float sum = 0.0;
3988 for (int64_t i02 = start; i02 < end; i02++) {
3989 for (int64_t i01 = 0; i01 < ne01; i01++) {
3990 const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
3991
3992 ggml_float sumr = 0.0;
3993 for (int64_t i00 = 0; i00 < ne00; i00++) {
3994 sumr += (ggml_float)x[i00];
3995 }
3996 sum += sumr;
3997 }
3998 }
3999 const float mean = sum / (ne00 * ne01 * step);
4000
4001 ggml_float sum2 = 0.0;
4002 for (int64_t i02 = start; i02 < end; i02++) {
4003 for (int64_t i01 = 0; i01 < ne01; i01++) {
4004 const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
4005
4006 float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
4007
4008 ggml_float sumr = 0.0;
4009 for (int64_t i00 = 0; i00 < ne00; i00++) {
4010 float v = x[i00] - mean;
4011 y[i00] = v;
4012 sumr += (ggml_float)(v * v);
4013 }
4014 sum2 += sumr;
4015 }
4016 }
4017 const float variance = sum2 / (ne00 * ne01 * step);
4018 const float scale = 1.0f / sqrtf(variance + eps);
4019
4020 for (int64_t i02 = start; i02 < end; i02++) {
4021 for (int64_t i01 = 0; i01 < ne01; i01++) {
4022 float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
4023 ggml_vec_scale_f32(ne00, y, scale);
4024 }
4025 }
4026 }
4027 }
4028}
4029
4030void ggml_compute_forward_group_norm(
4031 const ggml_compute_params * params,
4032 ggml_tensor * dst) {
4033
4034 const ggml_tensor * src0 = dst->src[0];
4035
4036 switch (src0->type) {
4037 case GGML_TYPE_F32:
4038 {
4039 ggml_compute_forward_group_norm_f32(params, dst);
4040 } break;
4041 default:
4042 {
4043 GGML_ABORT("fatal error");
4044 }
4045 }
4046}
4047
4048// ggml_compute_forward_l2_norm
4049
4050static void ggml_compute_forward_l2_norm_f32(
4051 const ggml_compute_params * params,
4052 ggml_tensor * dst) {
4053
4054 const ggml_tensor * src0 = dst->src[0];
4055
4056 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4057
4058 GGML_ASSERT(src0->nb[0] == sizeof(float));
4059
4060 const int ith = params->ith;
4061 const int nth = params->nth;
4062
4063 GGML_TENSOR_UNARY_OP_LOCALS
4064
4065 float eps;
4066 memcpy(&eps, dst->op_params, sizeof(float));
4067
4068 GGML_ASSERT(eps >= 0.0f);
4069
4070 // TODO: optimize
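    // per-row L2 normalization: y = x / max(||x||_2, eps)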
4071 for (int64_t i03 = 0; i03 < ne03; i03++) {
4072 for (int64_t i02 = 0; i02 < ne02; i02++) {
4073 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
4074 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
4075
4076 ggml_float sum = 0.0;
4077 for (int64_t i00 = 0; i00 < ne00; i00++) {
4078 sum += (ggml_float)(x[i00] * x[i00]);
4079 }
4080
4081 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
4082
4083 memcpy(y, x, ne00 * sizeof(float));
4084
4085 const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
4086
4087 ggml_vec_scale_f32(ne00, y, scale);
4088 }
4089 }
4090 }
4091}
4092
4093void ggml_compute_forward_l2_norm(
4094 const ggml_compute_params * params,
4095 ggml_tensor * dst) {
4096
4097 const ggml_tensor * src0 = dst->src[0];
4098
4099 switch (src0->type) {
4100 case GGML_TYPE_F32:
4101 {
4102 ggml_compute_forward_l2_norm_f32(params, dst);
4103 } break;
4104 default:
4105 {
4106 GGML_ABORT("fatal error");
4107 }
4108 }
4109}
4110
4111// ggml_compute_forward_out_prod
4112
4113static void ggml_compute_forward_out_prod_f32(
4114 const ggml_compute_params * params,
4115 ggml_tensor * dst) {
4116
4117 const ggml_tensor * src0 = dst->src[0];
4118 const ggml_tensor * src1 = dst->src[1];
4119
4120 GGML_TENSOR_BINARY_OP_LOCALS
4121
4122 GGML_ASSERT(dst->type == GGML_TYPE_F32);
4123 GGML_ASSERT(src0->type == GGML_TYPE_F32);
4124 GGML_ASSERT(src1->type == GGML_TYPE_F32);
4125
4126 const int ith = params->ith;
4127 const int nth = params->nth;
4128
4129 GGML_ASSERT(ne0 == ne00);
4130 GGML_ASSERT(ne1 == ne10);
4131 GGML_ASSERT(ne2 == ne12);
4132 GGML_ASSERT(ne3 == ne13);
4133
4134 GGML_ASSERT(ne2 % ne02 == 0);
4135 GGML_ASSERT(ne3 % ne03 == 0);
4136
4137 // we don't support permuted src0 or src1
4138 GGML_ASSERT(nb00 == sizeof(float));
4139
4140 // dst cannot be transposed or permuted
4141 GGML_ASSERT(nb0 == sizeof(float));
4142 // GGML_ASSERT(nb0 <= nb1);
4143 // GGML_ASSERT(nb1 <= nb2);
4144 // GGML_ASSERT(nb2 <= nb3);
4145
4146 // nb01 >= nb00 - src0 is not transposed
4147 // compute by src0 rows
4148
4149 if (ith == 0) {
4150 ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4151 }
4152 ggml_barrier(params->threadpool);
4153
4154 // dst[:,:,:,:] = 0
4155 // for i2,i3:
4156 // for i1:
4157 // for i01:
4158 // for i0:
4159 // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
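    //
    // per 2D slice this is dst = src0 * src1^T: each dst row i1 accumulates
    // src1[i1,i01] * src0[:,i01] over the shared index i01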
4160
4161 // parallelize by last three dimensions
4162
4163 // total rows in dst
4164 const int64_t nr = ne1*ne2*ne3;
4165
4166 // rows per thread
4167 const int64_t dr = (nr + nth - 1)/nth;
4168
4169 // row range for this thread
4170 const int64_t ir0 = dr*ith;
4171 const int64_t ir1 = MIN(ir0 + dr, nr);
4172
4173 // block-tiling attempt
4174 const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
4175 const int64_t blck_1 = 16;
4176
4177 // dps == dst per src0, used for group query attention
4178 const int64_t dps2 = ne2 / ne02;
4179 const int64_t dps3 = ne3 / ne03;
4180
4181 for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
4182 const int64_t bir1 = MIN(bir + blck_1, ir1);
4183 for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
4184 const int64_t bne01 = MIN(bi01 + blck_0, ne01);
4185 for (int64_t ir = bir; ir < bir1; ++ir) {
4186 // dst indices
4187 const int64_t i3 = ir/(ne2*ne1);
4188 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4189 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4190
4191 const int64_t i02 = i2 / dps2;
4192 const int64_t i03 = i3 / dps3;
4193
4194 //const int64_t i10 = i1;
4195 const int64_t i12 = i2;
4196 const int64_t i13 = i3;
4197
4198#if GGML_VEC_MAD_UNROLL > 2
4199 const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
4200 for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
4201 const int64_t i11 = i01;
4202
4203 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4204 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4205 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4206
4207 ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
4208 }
4209 for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
4210 const int64_t i11 = i01;
4211
4212 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4213 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4214 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4215
4216 ggml_vec_mad_f32(ne0, d, s0, *s1);
4217 }
4218#else
4219 for (int64_t i01 = bi01; i01 < bne01; ++i01) {
4220 const int64_t i11 = i01;
4221
4222 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4223 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4224 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4225
4226 ggml_vec_mad_f32(ne0, d, s0, *s1);
4227 }
4228#endif
4229 }
4230 }
4231 }
4232}
4233
4234static void ggml_compute_forward_out_prod_q_f32(
4235 const ggml_compute_params * params,
4236 ggml_tensor * dst) {
4237
4238 const ggml_tensor * src0 = dst->src[0];
4239 const ggml_tensor * src1 = dst->src[1];
4240
4241 GGML_TENSOR_BINARY_OP_LOCALS;
4242
4243 const int ith = params->ith;
4244 const int nth = params->nth;
4245
4246 const ggml_type type = src0->type;
4247 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4248
4249 GGML_ASSERT(ne02 == ne12);
4250 GGML_ASSERT(ne03 == ne13);
4251 GGML_ASSERT(ne2 == ne12);
4252 GGML_ASSERT(ne3 == ne13);
4253
4254 // we don't support permuted src0 dim0
4255 GGML_ASSERT(nb00 == ggml_type_size(type));
4256
4257 // dst dim0 cannot be transposed or permuted
4258 GGML_ASSERT(nb0 == sizeof(float));
4259 // GGML_ASSERT(nb0 <= nb1);
4260 // GGML_ASSERT(nb1 <= nb2);
4261 // GGML_ASSERT(nb2 <= nb3);
4262
4263 GGML_ASSERT(ne0 == ne00);
4264 GGML_ASSERT(ne1 == ne10);
4265 GGML_ASSERT(ne2 == ne02);
4266 GGML_ASSERT(ne3 == ne03);
4267
4268 // nb01 >= nb00 - src0 is not transposed
4269 // compute by src0 rows
4270
4271 if (ith == 0) {
4272 ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
4273 }
4274 ggml_barrier(params->threadpool);
4275
4276 // parallelize by last three dimensions
4277
4278 // total rows in dst
4279 const int64_t nr = ne1*ne2*ne3;
4280
4281 // rows per thread
4282 const int64_t dr = (nr + nth - 1)/nth;
4283
4284 // row range for this thread
4285 const int64_t ir0 = dr*ith;
4286 const int64_t ir1 = MIN(ir0 + dr, nr);
4287
4288 // dst[:,:,:,:] = 0
4289 // for i2,i3:
4290 // for i1:
4291 // for i01:
4292 // for i0:
4293 // dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
4294
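    // per-thread scratch row: each quantized src0 row is dequantized into wdata
    // before being accumulated into dst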
4295 float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
4296
4297 for (int64_t ir = ir0; ir < ir1; ++ir) {
4298 // dst indices
4299 const int64_t i3 = ir/(ne2*ne1);
4300 const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
4301 const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
4302
4303 const int64_t i02 = i2;
4304 const int64_t i03 = i3;
4305
4306 //const int64_t i10 = i1;
4307 const int64_t i12 = i2;
4308 const int64_t i13 = i3;
4309
4310 for (int64_t i01 = 0; i01 < ne01; ++i01) {
4311 const int64_t i11 = i01;
4312
4313 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
4314 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
4315 float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
4316
4317 dequantize_row_q(s0, wdata, ne0);
4318 ggml_vec_mad_f32(ne0, d, wdata, *s1);
4319 }
4320 }
4321}
4322
4323void ggml_compute_forward_out_prod(
4324 const ggml_compute_params * params,
4325 ggml_tensor * dst) {
4326
4327 const ggml_tensor * src0 = dst->src[0];
4328
4329 switch (src0->type) {
4330 case GGML_TYPE_Q4_0:
4331 case GGML_TYPE_Q4_1:
4332 case GGML_TYPE_Q5_0:
4333 case GGML_TYPE_Q5_1:
4334 case GGML_TYPE_Q8_0:
4335 case GGML_TYPE_MXFP4:
4336 case GGML_TYPE_Q2_K:
4337 case GGML_TYPE_Q3_K:
4338 case GGML_TYPE_Q4_K:
4339 case GGML_TYPE_Q5_K:
4340 case GGML_TYPE_Q6_K:
4341 case GGML_TYPE_TQ1_0:
4342 case GGML_TYPE_TQ2_0:
4343 case GGML_TYPE_IQ2_XXS:
4344 case GGML_TYPE_IQ2_XS:
4345 case GGML_TYPE_IQ3_XXS:
4346 case GGML_TYPE_IQ1_S:
4347 case GGML_TYPE_IQ1_M:
4348 case GGML_TYPE_IQ4_NL:
4349 case GGML_TYPE_IQ4_XS:
4350 case GGML_TYPE_IQ3_S:
4351 case GGML_TYPE_IQ2_S:
4352 {
4353 ggml_compute_forward_out_prod_q_f32(params, dst);
4354 } break;
4355 case GGML_TYPE_F16:
4356 {
4357 GGML_ABORT("fatal error"); // todo
4358 // ggml_compute_forward_out_prod_f16_f32(params, dst);
4359 }
4360 case GGML_TYPE_F32:
4361 {
4362 ggml_compute_forward_out_prod_f32(params, dst);
4363 } break;
4364 default:
4365 {
4366 GGML_ABORT("fatal error");
4367 }
4368 }
4369}
4370
4371// ggml_compute_forward_scale
4372
4373static void ggml_compute_forward_scale_f32(
4374 const ggml_compute_params * params,
4375 ggml_tensor * dst) {
4376
4377 const ggml_tensor * src0 = dst->src[0];
4378
4379 GGML_ASSERT(ggml_is_contiguous(src0));
4380 GGML_ASSERT(ggml_is_contiguous(dst));
4381 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4382
4383 float s; // scale factor
4384 float b; // bias
4385
4386 memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
4387 memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
4388
4389 const int ith = params->ith;
4390 const int nth = params->nth;
4391
4392 const int nc = src0->ne[0];
4393 const int nr = ggml_nrows(src0);
4394
4395 // rows per thread
4396 const int dr = (nr + nth - 1)/nth;
4397
4398 // row range for this thread
4399 const int ir0 = dr*ith;
4400 const int ir1 = MIN(ir0 + dr, nr);
4401
4402 const size_t nb01 = src0->nb[1];
4403
4404 const size_t nb1 = dst->nb[1];
4405
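    // dst = s*src0 + b (e.g. s = 2, b = 1 maps x -> 2*x + 1);
    // the b == 0 path skips the bias and only copies (if needed) and scales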
4406 if (b == 0.0f) {
4407 for (int i1 = ir0; i1 < ir1; i1++) {
4408 if (dst->data != src0->data) {
4409 // src0 is same shape as dst => same indices
4410 // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
4411 memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
4412 }
4413 ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
4414 }
4415 } else {
4416 for (int i1 = ir0; i1 < ir1; i1++) {
4417 ggml_vec_mad1_f32(nc,
4418 (float *) ((char *) dst->data + i1*nb1),
                    (float *) ((char *) src0->data + i1*nb01),
4420 s, b);
4421 }
4422 }
4423}
4424
4425void ggml_compute_forward_scale(
4426 const ggml_compute_params * params,
4427 ggml_tensor * dst) {
4428
4429 const ggml_tensor * src0 = dst->src[0];
4430
4431 switch (src0->type) {
4432 case GGML_TYPE_F32:
4433 {
4434 ggml_compute_forward_scale_f32(params, dst);
4435 } break;
4436 default:
4437 {
4438 GGML_ABORT("fatal error");
4439 }
4440 }
4441}
4442
4443// ggml_compute_forward_set
4444
4445static void ggml_compute_forward_set_f32(
4446 const ggml_compute_params * params,
4447 ggml_tensor * dst) {
4448
4449 const ggml_tensor * src0 = dst->src[0];
4450 const ggml_tensor * src1 = dst->src[1];
4451
4452 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4453 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4454
    // view src0 and dst with these strides and data offset in bytes during set
4456 // nb0 is implicitly element_size because src0 and dst are contiguous
4457 size_t nb1 = ((int32_t *) dst->op_params)[0];
4458 size_t nb2 = ((int32_t *) dst->op_params)[1];
4459 size_t nb3 = ((int32_t *) dst->op_params)[2];
4460 size_t offset = ((int32_t *) dst->op_params)[3];
4461 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
4462
4463 if (!inplace) {
4464 if (params->ith == 0) {
4465 // memcpy needs to be synchronized across threads to avoid race conditions.
4466 // => do it in INIT phase
4467 memcpy(
4468 ((char *) dst->data),
4469 ((char *) src0->data),
4470 ggml_nbytes(dst));
4471 }
4472 ggml_barrier(params->threadpool);
4473 }
4474
4475 const int ith = params->ith;
4476 const int nth = params->nth;
4477
4478 const int nr = ggml_nrows(src1);
4479 const int nc = src1->ne[0];
4480
4481 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4482 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
4483
4484 // src0 and dst as viewed during set
4485 const size_t nb0 = ggml_element_size(src0);
4486
4487 const int im0 = (ne10 == 0 ? 0 : ne10-1);
4488 const int im1 = (ne11 == 0 ? 0 : ne11-1);
4489 const int im2 = (ne12 == 0 ? 0 : ne12-1);
4490 const int im3 = (ne13 == 0 ? 0 : ne13-1);
4491
4492 GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
4493
4494 GGML_ASSERT(nb10 == sizeof(float));
4495
4496 // rows per thread
4497 const int dr = (nr + nth - 1)/nth;
4498
4499 // row range for this thread
4500 const int ir0 = dr*ith;
4501 const int ir1 = MIN(ir0 + dr, nr);
4502
4503 for (int ir = ir0; ir < ir1; ++ir) {
4504 // src0 and dst are viewed with shape of src1 and offset
4505 // => same indices
4506 const int i3 = ir/(ne12*ne11);
4507 const int i2 = (ir - i3*ne12*ne11)/ne11;
4508 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4509
4510 ggml_vec_cpy_f32(nc,
4511 (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
4512 (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4513 }
4514}
4515
4516static void ggml_compute_forward_set_i32(
4517 const ggml_compute_params * params,
4518 ggml_tensor * dst) {
4519
4520 const ggml_tensor * src0 = dst->src[0];
4521 const ggml_tensor * src1 = dst->src[1];
4522
4523 GGML_ASSERT(ggml_are_same_shape(src0, dst));
4524 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
4525
    // view src0 and dst with these strides and data offset in bytes during set
4527 // nb0 is implicitly element_size because src0 and dst are contiguous
4528 size_t nb1 = ((int32_t *) dst->op_params)[0];
4529 size_t nb2 = ((int32_t *) dst->op_params)[1];
4530 size_t nb3 = ((int32_t *) dst->op_params)[2];
4531 size_t offset = ((int32_t *) dst->op_params)[3];
4532 bool inplace = (bool) ((int32_t *) dst->op_params)[4];
4533
4534 if (!inplace) {
4535 if (params->ith == 0) {
4536 // memcpy needs to be synchronized across threads to avoid race conditions.
4537 // => do it in INIT phase
4538 memcpy(
4539 ((char *) dst->data),
4540 ((char *) src0->data),
4541 ggml_nbytes(dst));
4542 }
4543 ggml_barrier(params->threadpool);
4544 }
4545
4546 const int ith = params->ith;
4547 const int nth = params->nth;
4548
4549 const int nr = ggml_nrows(src1);
4550 const int nc = src1->ne[0];
4551
4552 GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
4553 GGML_TENSOR_LOCALS(size_t, nb1, src1, nb)
4554
4555 // src0 and dst as viewed during set
4556 const size_t nb0 = ggml_element_size(src0);
4557
4558 const int im0 = (ne10 == 0 ? 0 : ne10-1);
4559 const int im1 = (ne11 == 0 ? 0 : ne11-1);
4560 const int im2 = (ne12 == 0 ? 0 : ne12-1);
4561 const int im3 = (ne13 == 0 ? 0 : ne13-1);
4562
4563 GGML_ASSERT(offset + im0*nb0 + im1*nb1 + im2*nb2 + im3*nb3 <= ggml_nbytes(dst));
4564
4565 GGML_ASSERT(nb10 == sizeof(int32_t));
4566
4567 // rows per thread
4568 const int dr = (nr + nth - 1)/nth;
4569
4570 // row range for this thread
4571 const int ir0 = dr*ith;
4572 const int ir1 = MIN(ir0 + dr, nr);
4573
4574 for (int ir = ir0; ir < ir1; ++ir) {
4575 // src0 and dst are viewed with shape of src1 and offset
4576 // => same indices
4577 const int i3 = ir/(ne12*ne11);
4578 const int i2 = (ir - i3*ne12*ne11)/ne11;
4579 const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
4580
4581 ggml_vec_cpy_i32(nc,
4582 (int32_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + offset),
4583 (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
4584 }
4585}
4586
4587void ggml_compute_forward_set(
4588 const ggml_compute_params * params,
4589 ggml_tensor * dst) {
4590
4591 const ggml_tensor * src0 = dst->src[0];
4592
4593 switch (src0->type) {
4594 case GGML_TYPE_F32:
4595 {
4596 ggml_compute_forward_set_f32(params, dst);
4597 } break;
4598 case GGML_TYPE_I32:
4599 {
4600 ggml_compute_forward_set_i32(params, dst);
4601 } break;
4602 case GGML_TYPE_F16:
4603 case GGML_TYPE_BF16:
4604 case GGML_TYPE_Q4_0:
4605 case GGML_TYPE_Q4_1:
4606 case GGML_TYPE_Q5_0:
4607 case GGML_TYPE_Q5_1:
4608 case GGML_TYPE_Q8_0:
4609 case GGML_TYPE_Q8_1:
4610 case GGML_TYPE_MXFP4:
4611 case GGML_TYPE_Q2_K:
4612 case GGML_TYPE_Q3_K:
4613 case GGML_TYPE_Q4_K:
4614 case GGML_TYPE_Q5_K:
4615 case GGML_TYPE_Q6_K:
4616 case GGML_TYPE_TQ1_0:
4617 case GGML_TYPE_TQ2_0:
4618 case GGML_TYPE_IQ2_XXS:
4619 case GGML_TYPE_IQ2_XS:
4620 case GGML_TYPE_IQ3_XXS:
4621 case GGML_TYPE_IQ1_S:
4622 case GGML_TYPE_IQ1_M:
4623 case GGML_TYPE_IQ4_NL:
4624 case GGML_TYPE_IQ4_XS:
4625 case GGML_TYPE_IQ3_S:
4626 case GGML_TYPE_IQ2_S:
4627 default:
4628 {
4629 GGML_ABORT("fatal error");
4630 }
4631 }
4632}
4633
4634// ggml_compute_forward_cpy
4635
4636void ggml_compute_forward_cpy(
4637 const ggml_compute_params * params,
4638 ggml_tensor * dst) {
4639 ggml_compute_forward_dup(params, dst);
4640}
4641
4642// ggml_compute_forward_cont
4643
4644void ggml_compute_forward_cont(
4645 const ggml_compute_params * params,
4646 ggml_tensor * dst) {
4647 ggml_compute_forward_dup(params, dst);
4648}
4649
4650// ggml_compute_forward_get_rows
4651
4652static void ggml_compute_forward_get_rows_q(
4653 const ggml_compute_params * params,
4654 ggml_tensor * dst) {
4655
4656 const ggml_tensor * src0 = dst->src[0];
4657 const ggml_tensor * src1 = dst->src[1];
4658
4659 GGML_TENSOR_BINARY_OP_LOCALS
4660
4661 const int64_t nc = ne00;
4662 const int64_t nr = ggml_nelements(src1);
4663
4664 const ggml_type type = src0->type;
4665 ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
4666
4667 assert(ne0 == nc);
4668 assert(ne02 == ne11);
4669 assert(nb00 == ggml_type_size(type));
4670 assert(ggml_nrows(dst) == nr);
4671
4672 const int ith = params->ith;
4673 const int nth = params->nth;
4674
4675 // rows per thread
4676 const int dr = (nr + nth - 1)/nth;
4677
4678 // row range for this thread
4679 const int ir0 = dr*ith;
4680 const int ir1 = MIN(ir0 + dr, nr);
4681
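    // each element of src1 is a row index into src0; the selected quantized row is
    // dequantized into the corresponding f32 row of dst (e.g. in the non-broadcast
    // case src1 = [2, 0] yields dst row 0 = src0 row 2, dst row 1 = src0 row 0)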
4682 for (int64_t i = ir0; i < ir1; ++i) {
4683 const int64_t i12 = i/(ne11*ne10);
4684 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4685 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4686 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4687
4688 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4689
4690 dequantize_row_q(
4691 (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4692 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4693 }
4694}
4695
4696static void ggml_compute_forward_get_rows_f16(
4697 const ggml_compute_params * params,
4698 ggml_tensor * dst) {
4699
4700 const ggml_tensor * src0 = dst->src[0];
4701 const ggml_tensor * src1 = dst->src[1];
4702
4703 GGML_TENSOR_BINARY_OP_LOCALS
4704
4705 const int64_t nc = ne00;
4706 const int64_t nr = ggml_nelements(src1);
4707
4708 assert(ne0 == nc);
4709 assert(ne02 == ne11);
4710 assert(nb00 == sizeof(ggml_fp16_t));
4711 assert(ggml_nrows(dst) == nr);
4712
4713 const int ith = params->ith;
4714 const int nth = params->nth;
4715
4716 // rows per thread
4717 const int dr = (nr + nth - 1)/nth;
4718
4719 // row range for this thread
4720 const int ir0 = dr*ith;
4721 const int ir1 = MIN(ir0 + dr, nr);
4722
4723 for (int64_t i = ir0; i < ir1; ++i) {
4724 const int64_t i12 = i/(ne11*ne10);
4725 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4726 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4727 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4728
4729 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4730
4731 ggml_cpu_fp16_to_fp32(
4732 (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4733 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4734 }
4735}
4736
4737static void ggml_compute_forward_get_rows_bf16(
4738 const ggml_compute_params * params,
4739 ggml_tensor * dst) {
4740
4741 const ggml_tensor * src0 = dst->src[0];
4742 const ggml_tensor * src1 = dst->src[1];
4743
4744 GGML_TENSOR_BINARY_OP_LOCALS
4745
4746 const int64_t nc = ne00;
4747 const int64_t nr = ggml_nelements(src1);
4748
4749 assert(ne0 == nc);
4750 assert(ne02 == ne11);
4751 assert(nb00 == sizeof(ggml_bf16_t));
4752 assert(ggml_nrows(dst) == nr);
4753
4754 const int ith = params->ith;
4755 const int nth = params->nth;
4756
4757 // rows per thread
4758 const int dr = (nr + nth - 1)/nth;
4759
4760 // row range for this thread
4761 const int ir0 = dr*ith;
4762 const int ir1 = MIN(ir0 + dr, nr);
4763
4764 for (int64_t i = ir0; i < ir1; ++i) {
4765 const int64_t i12 = i/(ne11*ne10);
4766 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4767 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4768 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4769
4770 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4771
4772 ggml_cpu_bf16_to_fp32(
4773 (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
4774 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
4775 }
4776}
4777
4778static void ggml_compute_forward_get_rows_f32(
4779 const ggml_compute_params * params,
4780 ggml_tensor * dst) {
4781
4782 const ggml_tensor * src0 = dst->src[0];
4783 const ggml_tensor * src1 = dst->src[1];
4784
4785 GGML_TENSOR_BINARY_OP_LOCALS
4786
4787 const int64_t nc = ne00;
4788 const int64_t nr = ggml_nelements(src1);
4789
4790 assert(ne0 == nc);
4791 assert(ne02 == ne11);
4792 assert(nb00 == sizeof(float));
4793 assert(ggml_nrows(dst) == nr);
4794
4795 const int ith = params->ith;
4796 const int nth = params->nth;
4797
4798 // rows per thread
4799 const int dr = (nr + nth - 1)/nth;
4800
4801 // row range for this thread
4802 const int ir0 = dr*ith;
4803 const int ir1 = MIN(ir0 + dr, nr);
4804
4805 for (int64_t i = ir0; i < ir1; ++i) {
4806 const int64_t i12 = i/(ne11*ne10);
4807 const int64_t i11 = (i - i12*ne11*ne10)/ne10;
4808 const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
4809 const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4810
4811 GGML_ASSERT(i01 >= 0 && i01 < ne01);
4812
4813 ggml_vec_cpy_f32(nc,
4814 (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
4815 (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
4816 }
4817}
4818
4819void ggml_compute_forward_get_rows(
4820 const ggml_compute_params * params,
4821 ggml_tensor * dst) {
4822
4823 const ggml_tensor * src0 = dst->src[0];
4824
4825 switch (src0->type) {
4826 case GGML_TYPE_Q4_0:
4827 case GGML_TYPE_Q4_1:
4828 case GGML_TYPE_Q5_0:
4829 case GGML_TYPE_Q5_1:
4830 case GGML_TYPE_Q8_0:
4831 case GGML_TYPE_Q8_1:
4832 case GGML_TYPE_MXFP4:
4833 case GGML_TYPE_Q2_K:
4834 case GGML_TYPE_Q3_K:
4835 case GGML_TYPE_Q4_K:
4836 case GGML_TYPE_Q5_K:
4837 case GGML_TYPE_Q6_K:
4838 case GGML_TYPE_TQ1_0:
4839 case GGML_TYPE_TQ2_0:
4840 case GGML_TYPE_IQ2_XXS:
4841 case GGML_TYPE_IQ2_XS:
4842 case GGML_TYPE_IQ3_XXS:
4843 case GGML_TYPE_IQ1_S:
4844 case GGML_TYPE_IQ1_M:
4845 case GGML_TYPE_IQ4_NL:
4846 case GGML_TYPE_IQ4_XS:
4847 case GGML_TYPE_IQ3_S:
4848 case GGML_TYPE_IQ2_S:
4849 {
4850 ggml_compute_forward_get_rows_q(params, dst);
4851 } break;
4852 case GGML_TYPE_F16:
4853 {
4854 ggml_compute_forward_get_rows_f16(params, dst);
4855 } break;
4856 case GGML_TYPE_BF16:
4857 {
4858 ggml_compute_forward_get_rows_bf16(params, dst);
4859 } break;
4860 case GGML_TYPE_F32:
4861 case GGML_TYPE_I32:
4862 {
4863 ggml_compute_forward_get_rows_f32(params, dst);
4864 } break;
4865 default:
4866 {
4867 GGML_ABORT("fatal error");
4868 }
4869 }
4870
4871 //static bool first = true;
4872 //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
4873 //if (first) {
4874 // first = false;
4875 //} else {
4876 // for (int k = 0; k < dst->ne[1]; ++k) {
4877 // for (int j = 0; j < dst->ne[0]/16; ++j) {
4878 // for (int i = 0; i < 16; ++i) {
4879 // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
4880 // }
4881 // printf("\n");
4882 // }
4883 // printf("\n");
4884 // }
4885 // printf("\n");
4886 // exit(0);
4887 //}
4888}
4889
4890template<typename idx_t>
4891static void ggml_compute_forward_set_rows_f32(
4892 const ggml_compute_params * params,
4893 ggml_tensor * dst) {
4894
4895 const ggml_tensor * src0 = dst->src[0];
4896 const ggml_tensor * src1 = dst->src[1];
4897
4898 GGML_TENSOR_BINARY_OP_LOCALS
4899
4900 const int64_t nc = ne00;
4901 const int64_t nr = ne01;
4902
4903 assert(ne0 == nc);
4904 assert(ne2 == ne02);
4905 assert(ne3 == ne03);
4906 assert(src0->type == GGML_TYPE_F32);
4907 assert(ne02 % ne11 == 0);
4908 assert(ne03 % ne12 == 0);
4909
4910 const int ith = params->ith;
4911 const int nth = params->nth;
4912
4913 // rows per thread
4914 const int64_t dr = (nr + nth - 1)/nth;
4915
4916 // row range for this thread
4917 const int64_t ir0 = dr*ith;
4918 const int64_t ir1 = std::min(ir0 + dr, nr);
4919
4920 ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
4921
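    // src1 supplies the destination row index for each source row: every f32 row of
    // src0 is converted to dst->type and written to row i1 of dst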
4922 for (int64_t i03 = 0; i03 < ne03; ++i03) {
4923 for (int64_t i02 = 0; i02 < ne02; ++i02) {
4924 for (int64_t i = ir0; i < ir1; ++i) {
4925 const int64_t i12 = i03%ne12;
4926 const int64_t i11 = i02%ne11;
4927 const int64_t i10 = i;
4928
4929 const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
4930
4931 GGML_ASSERT(i1 >= 0 && i1 < ne1);
4932
4933 from_float(
4934 (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
4935 ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc);
4936 }
4937 }
4938 }
4939}
4940
4941void ggml_compute_forward_set_rows(
4942 const ggml_compute_params * params,
4943 ggml_tensor * dst) {
4944
4945 const ggml_tensor * src0 = dst->src[0];
4946 const ggml_tensor * src1 = dst->src[1];
4947
4948 switch (src0->type) {
4949 case GGML_TYPE_F32:
4950 {
4951 if (src1->type == GGML_TYPE_I64) {
4952 ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
4953 } else if (src1->type == GGML_TYPE_I32) {
4954 ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
4955 } else {
4956 GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
4957 }
4958 } break;
4959 default:
4960 {
4961 GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
4962 }
4963 }
4964}
4965
4966// ggml_compute_forward_get_rows_back
4967
4968static void ggml_compute_forward_get_rows_back_f32_f16(
4969 const ggml_compute_params * params,
4970 ggml_tensor * dst) {
4971
4972 const ggml_tensor * src0 = dst->src[0];
4973 const ggml_tensor * src1 = dst->src[1];
4974
4975 if (params->ith != 0) {
4976 return;
4977 }
4978
4979 GGML_ASSERT(ggml_is_contiguous(dst));
4980
4981 // ggml_compute_forward_dup_same_cont(params, opt0, dst);
4982
4983 memset(dst->data, 0, ggml_nbytes(dst));
4984
4985 const int nc = src0->ne[0];
4986 const int nr = ggml_nelements(src1);
4987
4988 GGML_ASSERT( dst->ne[0] == nc);
4989 GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
4990
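    // backward of get_rows: scatter-add each row of src0 (the incoming gradient) into
    // the dst row selected by src1; kept single-threaded since several source rows may
    // target the same destination row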
4991 for (int i = 0; i < nr; ++i) {
4992 const int r = ((int32_t *) src1->data)[i];
4993
4994 for (int j = 0; j < nc; ++j) {
4995 ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
4996 ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
4997 }
4998 }
4999}
5000
5001static void ggml_compute_forward_get_rows_back_f32(
5002 const ggml_compute_params * params,
5003 ggml_tensor * dst) {
5004
5005 const ggml_tensor * src0 = dst->src[0];
5006 const ggml_tensor * src1 = dst->src[1];
5007
5008 if (params->ith != 0) {
5009 return;
5010 }
5011
5012 GGML_ASSERT(ggml_is_contiguous(dst));
5013
5014 // ggml_compute_forward_dup_same_cont(params, opt0, dst);
5015
5016 memset(dst->data, 0, ggml_nbytes(dst));
5017
5018 const int nc = src0->ne[0];
5019 const int nr = ggml_nelements(src1);
5020
5021 GGML_ASSERT( dst->ne[0] == nc);
5022 GGML_ASSERT(src0->nb[0] == sizeof(float));
5023
5024 for (int i = 0; i < nr; ++i) {
5025 const int r = ((int32_t *) src1->data)[i];
5026
5027 ggml_vec_add_f32(nc,
5028 (float *) ((char *) dst->data + r*dst->nb[1]),
5029 (float *) ((char *) dst->data + r*dst->nb[1]),
5030 (float *) ((char *) src0->data + i*src0->nb[1]));
5031 }
5032}
5033
5034void ggml_compute_forward_get_rows_back(
5035 const ggml_compute_params * params,
5036 ggml_tensor * dst) {
5037
5038 const ggml_tensor * src0 = dst->src[0];
5039
5040 switch (src0->type) {
5041 case GGML_TYPE_F16:
5042 {
5043 ggml_compute_forward_get_rows_back_f32_f16(params, dst);
5044 } break;
5045 case GGML_TYPE_F32:
5046 {
5047 ggml_compute_forward_get_rows_back_f32(params, dst);
5048 } break;
5049 default:
5050 {
5051 GGML_ABORT("fatal error");
5052 }
5053 }
5054
5055 //static bool first = true;
5056 //printf("ne0 = %d, ne1 = %d, ne2 = %d\n", dst->ne[0], dst->ne[1], dst->ne[2]);
5057 //if (first) {
5058 // first = false;
5059 //} else {
5060 // for (int k = 0; k < dst->ne[1]; ++k) {
5061 // for (int j = 0; j < dst->ne[0]/16; ++j) {
5062 // for (int i = 0; i < 16; ++i) {
5063 // printf("%8.4f ", ((float *) dst->data)[k*dst->ne[0] + j*16 + i]);
5064 // }
5065 // printf("\n");
5066 // }
5067 // printf("\n");
5068 // }
5069 // printf("\n");
5070 // exit(0);
5071 //}
5072}
5073
5074// ggml_compute_forward_diag
5075
5076static void ggml_compute_forward_diag_f32(
5077 const ggml_compute_params * params,
5078 ggml_tensor * dst) {
5079
5080 const ggml_tensor * src0 = dst->src[0];
5081
5082 if (params->ith != 0) {
5083 return;
5084 }
5085
5086 // TODO: handle transposed/permuted matrices
5087
5088 GGML_TENSOR_UNARY_OP_LOCALS
5089
5090 GGML_ASSERT(ne00 == ne0);
5091 GGML_ASSERT(ne00 == ne1);
5092 GGML_ASSERT(ne01 == 1);
5093 GGML_ASSERT(ne02 == ne2);
5094 GGML_ASSERT(ne03 == ne3);
5095
5096 GGML_ASSERT(nb00 == sizeof(float));
5097 GGML_ASSERT(nb0 == sizeof(float));
5098
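    // write the single src0 row along the diagonal of a square ne0 x ne1 slice and
    // zero everything off the diagonal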
5099 for (int i3 = 0; i3 < ne3; i3++) {
5100 for (int i2 = 0; i2 < ne2; i2++) {
5101 for (int i1 = 0; i1 < ne1; i1++) {
5102 float * d = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5103 float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
5104 for (int i0 = 0; i0 < i1; i0++) {
5105 d[i0] = 0;
5106 }
5107 d[i1] = s[i1];
5108 for (int i0 = i1+1; i0 < ne0; i0++) {
5109 d[i0] = 0;
5110 }
5111 }
5112 }
5113 }
5114}
5115
5116void ggml_compute_forward_diag(
5117 const ggml_compute_params * params,
5118 ggml_tensor * dst) {
5119
5120 const ggml_tensor * src0 = dst->src[0];
5121
5122 switch (src0->type) {
5123 case GGML_TYPE_F32:
5124 {
5125 ggml_compute_forward_diag_f32(params, dst);
5126 } break;
5127 default:
5128 {
5129 GGML_ABORT("fatal error");
5130 }
5131 }
5132}
5133
5134// ggml_compute_forward_diag_mask_inf
5135
5136static void ggml_compute_forward_diag_mask_f32(
5137 const ggml_compute_params * params,
5138 ggml_tensor * dst,
5139 const float value) {
5140
5141 const ggml_tensor * src0 = dst->src[0];
5142
5143 const int ith = params->ith;
5144 const int nth = params->nth;
5145
5146 const int n_past = ((int32_t *) dst->op_params)[0];
5147 const bool inplace = src0->data == dst->data;
5148
5149 GGML_ASSERT(n_past >= 0);
5150
5151 if (!inplace) {
5152 if (ith == 0) {
5153 // memcpy needs to be synchronized across threads to avoid race conditions.
5154 // => do it in INIT phase
5155 GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
5156 GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
5157 memcpy(
5158 ((char *) dst->data),
5159 ((char *) src0->data),
5160 ggml_nbytes(dst));
5161 }
5162 ggml_barrier(params->threadpool);
5163 }
5164
5165 // TODO: handle transposed/permuted matrices
5166
5167 const int n = ggml_nrows(src0);
5168 const int nc = src0->ne[0];
5169 const int nr = src0->ne[1];
5170 const int nz = n/nr;
5171
5172 GGML_ASSERT( dst->nb[0] == sizeof(float));
5173 GGML_ASSERT(src0->nb[0] == sizeof(float));
5174
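    // in row j, every column i > n_past + j is overwritten with `value`
    // (-INF for diag_mask_inf, 0 for diag_mask_zero); e.g. with n_past = 1,
    // row 0 keeps columns 0..1, row 1 keeps columns 0..2, and so on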
5175 for (int k = 0; k < nz; k++) {
5176 for (int j = ith; j < nr; j += nth) {
5177 for (int i = n_past; i < nc; i++) {
5178 if (i > n_past + j) {
5179 *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
5180 }
5181 }
5182 }
5183 }
5184}
5185
5186void ggml_compute_forward_diag_mask_inf(
5187 const ggml_compute_params * params,
5188 ggml_tensor * dst) {
5189
5190 const ggml_tensor * src0 = dst->src[0];
5191
5192 switch (src0->type) {
5193 case GGML_TYPE_F32:
5194 {
5195 ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
5196 } break;
5197 default:
5198 {
5199 GGML_ABORT("fatal error");
5200 }
5201 }
5202}
5203
5204void ggml_compute_forward_diag_mask_zero(
5205 const ggml_compute_params * params,
5206 ggml_tensor * dst) {
5207
5208 const ggml_tensor * src0 = dst->src[0];
5209
5210 switch (src0->type) {
5211 case GGML_TYPE_F32:
5212 {
5213 ggml_compute_forward_diag_mask_f32(params, dst, 0);
5214 } break;
5215 default:
5216 {
5217 GGML_ABORT("fatal error");
5218 }
5219 }
5220}
5221
5222// ggml_compute_forward_soft_max
5223
5224static void ggml_compute_forward_soft_max_f32(
5225 const ggml_compute_params * params,
5226 ggml_tensor * dst) {
5227
5228 const ggml_tensor * src0 = dst->src[0];
5229 const ggml_tensor * src1 = dst->src[1];
5230 const ggml_tensor * src2 = dst->src[2];
5231
5232 assert(ggml_is_contiguous(dst));
5233 assert(ggml_are_same_shape(src0, dst));
5234
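    // dst = soft_max(scale*src0 + slope(head)*mask), with an optional ALiBi slope per
    // head (max_bias > 0) and optional attention sinks (src2) that only affect the
    // normalization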
5235 float scale = 1.0f;
5236 float max_bias = 0.0f;
5237
5238 memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
5239 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
5240
5241 const int ith = params->ith;
5242 const int nth = params->nth;
5243
5244 GGML_TENSOR_UNARY_OP_LOCALS
5245
5246 const int64_t nb11 = src1 ? src1->nb[1] : 1;
5247 const int64_t nb12 = src1 ? src1->nb[2] : 1;
5248 const int64_t nb13 = src1 ? src1->nb[3] : 1;
5249
5250 const int64_t ne12 = src1 ? src1->ne[2] : 1;
5251 const int64_t ne13 = src1 ? src1->ne[3] : 1;
5252
5253 // TODO: is this supposed to be ceil instead of floor?
5254 // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
5255 const uint32_t n_head = ne02;
5256 const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
5257
5258 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5259 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5260
5261 float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
5262
5263 const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
5264
5265 // sinks
5266 const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
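    // sk[i02] is an extra per-head logit that is folded into the max and the sum below
    // but is never written to dst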
5267
5268 for (int64_t i03 = 0; i03 < ne03; i03++) {
5269 for (int64_t i02 = 0; i02 < ne02; i02++) {
5270 for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
5271 const int64_t i11 = i01;
5272 const int64_t i12 = i02%ne12;
5273 const int64_t i13 = i03%ne13;
5274
5275 // ALiBi
5276 const uint32_t h = i02; // head
5277 const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
5278
5279 float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
5280 float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
5281
5282 // broadcast the mask across rows
5283 ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5284 float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
5285
5286 ggml_vec_cpy_f32 (ne00, wp, sp);
5287 ggml_vec_scale_f32(ne00, wp, scale);
5288 if (mp_f32) {
5289 if (use_f16) {
5290 for (int i = 0; i < ne00; ++i) {
5291 wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
5292 }
5293 } else {
5294 for (int i = 0; i < ne00; ++i) {
5295 wp[i] += slope*mp_f32[i];
5296 }
5297 }
5298 }
5299
5300#ifndef NDEBUG
5301 for (int i = 0; i < ne00; ++i) {
5302 //printf("p[%d] = %f\n", i, p[i]);
5303 assert(!isnan(wp[i]));
5304 }
5305#endif
5306
5307 float max = -INFINITY;
5308 ggml_vec_max_f32(ne00, &max, wp);
5309
5310 // if we have sinks, make a correction as if they were included in the softmax
5311 if (sk) {
5312 max = MAX(max, sk[i02]);
5313 }
5314
5315 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
5316 assert(sum > 0.0);
5317
5318 if (sk) {
5319 sum += (ggml_float) expf(sk[i02] - max);
5320 }
5321
5322 sum = 1.0/sum;
5323 ggml_vec_scale_f32(ne00, dp, sum);
5324
5325#ifndef NDEBUG
5326 for (int i = 0; i < ne00; ++i) {
5327 assert(!isnan(dp[i]));
5328 assert(!isinf(dp[i]));
5329 }
5330#endif
5331 }
5332 }
5333 }
5334}
5335
5336void ggml_compute_forward_soft_max(
5337 const ggml_compute_params * params,
5338 ggml_tensor * dst) {
5339
5340 const ggml_tensor * src0 = dst->src[0];
5341
5342 switch (src0->type) {
5343 case GGML_TYPE_F32:
5344 {
5345 ggml_compute_forward_soft_max_f32(params, dst);
5346 } break;
5347 default:
5348 {
5349 GGML_ABORT("fatal error");
5350 }
5351 }
5352}
5353
5354
5355// ggml_compute_forward_soft_max_ext_back
5356
5357static void ggml_compute_forward_soft_max_ext_back_f32(
5358 const ggml_compute_params * params,
5359 ggml_tensor * dst) {
5360
5361 const ggml_tensor * src0 = dst->src[0];
5362 const ggml_tensor * src1 = dst->src[1];
5363
5364 GGML_ASSERT(ggml_is_contiguous(src0));
5365 GGML_ASSERT(ggml_is_contiguous(src1));
5366 GGML_ASSERT(ggml_is_contiguous(dst));
5367 GGML_ASSERT(ggml_are_same_shape(src0, dst));
5368 GGML_ASSERT(ggml_are_same_shape(src1, dst));
5369
5370 float scale = 1.0f;
5371 float max_bias = 0.0f;
5372
5373 memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
5374 memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
5375
5376 GGML_ASSERT(max_bias == 0.0f);
5377
5378 // TODO: handle transposed/permuted matrices
5379
5380 const int ith = params->ith;
5381 const int nth = params->nth;
5382
5383 const int nc = src0->ne[0];
5384 const int nr = ggml_nrows(src0);
5385
5386 // rows per thread
5387 const int dr = (nr + nth - 1)/nth;
5388
5389 // row range for this thread
5390 const int ir0 = dr*ith;
5391 const int ir1 = MIN(ir0 + dr, nr);
5392
5393 for (int i1 = ir0; i1 < ir1; i1++) {
5394 float *dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5395 float *y = (float *)((char *) src1->data + i1*src1->nb[1]);
5396 float *dx = (float *)((char *) dst->data + i1*dst->nb[1]);
5397
5398#ifndef NDEBUG
5399 for (int i = 0; i < nc; ++i) {
5400 //printf("p[%d] = %f\n", i, p[i]);
5401 assert(!isnan(dy[i]));
5402 assert(!isnan(y[i]));
5403 }
5404#endif
5405 // Jii = yi - yi*yi
5406 // Jij = -yi*yj
5407 // J = diag(y)-y.T*y
5408 // dx = J * dy
5409 // dxk = sum_i(Jki * dyi)
5410 // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
5411 // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
5412 // dxk = sum_i(-yk*yi * dyi) + yk*dyk
5413 // dxk = -yk * sum_i(yi * dyi) + yk*dyk
5414 // dxk = -yk * dot(y, dy) + yk*dyk
5415 // dxk = yk * (- dot(y, dy) + dyk)
5416 // dxk = yk * (dyk - dot(y, dy))
5417 //
5418 // post-order:
5419 // dot_y_dy := dot(y, dy)
5420 // dx := dy
5421 // dx := dx - dot_y_dy
5422 // dx := dx * y
5423
5424 // linear runtime, no additional memory
5425 float dot_y_dy = 0;
5426 ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
5427 ggml_vec_cpy_f32 (nc, dx, dy);
5428 ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
5429 ggml_vec_mul_f32 (nc, dx, dx, y);
5430 ggml_vec_scale_f32(nc, dx, scale);
5431
5432#ifndef NDEBUG
5433 for (int i = 0; i < nc; ++i) {
5434 assert(!isnan(dx[i]));
5435 assert(!isinf(dx[i]));
5436 }
5437#endif
5438 }
5439}
5440
5441void ggml_compute_forward_soft_max_ext_back(
5442 const ggml_compute_params * params,
5443 ggml_tensor * dst) {
5444
5445 const ggml_tensor * src0 = dst->src[0];
5446
5447 switch (src0->type) {
5448 case GGML_TYPE_F32:
5449 {
5450 ggml_compute_forward_soft_max_ext_back_f32(params, dst);
5451 } break;
5452 default:
5453 {
5454 GGML_ABORT("fatal error");
5455 }
5456 }
5457}
5458
5459// ggml_compute_forward_clamp
5460
5461static void ggml_compute_forward_clamp_f32(
5462 const ggml_compute_params * params,
5463 ggml_tensor * dst) {
5464
5465 const ggml_tensor * src0 = dst->src[0];
5466
5467 float min;
5468 float max;
5469 memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5470 memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5471
5472 const int ith = params->ith;
5473 const int nth = params->nth;
5474
5475 const int n = ggml_nrows(src0);
5476 const int nc = src0->ne[0];
5477
5478 const size_t nb00 = src0->nb[0];
5479 const size_t nb01 = src0->nb[1];
5480
5481 const size_t nb0 = dst->nb[0];
5482 const size_t nb1 = dst->nb[1];
5483
5484 GGML_ASSERT( nb0 == sizeof(float));
5485 GGML_ASSERT(nb00 == sizeof(float));
5486
5487 for (int j = ith; j < n; j += nth) {
5488 float * dst_ptr = (float *) ((char *) dst->data + j*nb1);
5489 float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
5490
5491 for (int i = 0; i < nc; i++) {
5492 dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
5493 }
5494 }
5495}
5496
5497static void ggml_compute_forward_clamp_f16(
5498 const ggml_compute_params * params,
5499 ggml_tensor * dst) {
5500
5501 const ggml_tensor * src0 = dst->src[0];
5502
5503 float min;
5504 float max;
5505 memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
5506 memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
5507
5508 const int ith = params->ith;
5509 const int nth = params->nth;
5510
5511 const int n = ggml_nrows(src0);
5512 const int nc = src0->ne[0];
5513
5514 const size_t nb00 = src0->nb[0];
5515 const size_t nb01 = src0->nb[1];
5516
5517 const size_t nb0 = dst->nb[0];
5518 const size_t nb1 = dst->nb[1];
5519
5520 GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
5521 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5522
5523 for (int j = ith; j < n; j += nth) {
5524 ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
5525 ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
5526
5527 for (int i = 0; i < nc; i++) {
5528 float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
5529 dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
5530 }
5531 }
5532}
5533
5534void ggml_compute_forward_clamp(
5535 const ggml_compute_params * params,
5536 ggml_tensor * dst) {
5537
5538 const ggml_tensor * src0 = dst->src[0];
5539
5540 switch (src0->type) {
5541 case GGML_TYPE_F32:
5542 {
5543 ggml_compute_forward_clamp_f32(params, dst);
5544 } break;
5545 case GGML_TYPE_F16:
5546 {
5547 ggml_compute_forward_clamp_f16(params, dst);
5548 } break;
5549 case GGML_TYPE_BF16:
5550 case GGML_TYPE_Q4_0:
5551 case GGML_TYPE_Q4_1:
5552 case GGML_TYPE_Q5_0:
5553 case GGML_TYPE_Q5_1:
5554 case GGML_TYPE_Q8_0:
5555 case GGML_TYPE_Q8_1:
5556 case GGML_TYPE_MXFP4:
5557 case GGML_TYPE_Q2_K:
5558 case GGML_TYPE_Q3_K:
5559 case GGML_TYPE_Q4_K:
5560 case GGML_TYPE_Q5_K:
5561 case GGML_TYPE_Q6_K:
5562 case GGML_TYPE_TQ1_0:
5563 case GGML_TYPE_TQ2_0:
5564 case GGML_TYPE_IQ2_XXS:
5565 case GGML_TYPE_IQ2_XS:
5566 case GGML_TYPE_IQ3_XXS:
5567 case GGML_TYPE_IQ1_S:
5568 case GGML_TYPE_IQ1_M:
5569 case GGML_TYPE_IQ4_NL:
5570 case GGML_TYPE_IQ4_XS:
5571 case GGML_TYPE_IQ3_S:
5572 case GGML_TYPE_IQ2_S:
5573 case GGML_TYPE_Q8_K:
5574 case GGML_TYPE_I8:
5575 case GGML_TYPE_I16:
5576 case GGML_TYPE_I32:
5577 case GGML_TYPE_I64:
5578 case GGML_TYPE_F64:
5579 case GGML_TYPE_COUNT:
5580 {
5581 GGML_ABORT("fatal error");
5582 }
5583 }
5584}
5585
5586// ggml_compute_forward_rope
5587
5588static float rope_yarn_ramp(const float low, const float high, const int i0) {
5589 const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
5590 return 1 - MIN(1, MAX(0, y));
5591}
5592
5593// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
5594// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
5595static void rope_yarn(
5596 float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
5597 float * cos_theta, float * sin_theta) {
5598 // Get n-d rotational scaling corrected for extrapolation
5599 float theta_interp = freq_scale * theta_extrap;
5600 float theta = theta_interp;
5601 if (ext_factor != 0.0f) {
5602 float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
5603 theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
5604
5605 // Get n-d magnitude scaling corrected for interpolation
5606 mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
5607 }
5608 *cos_theta = cosf(theta) * mscale;
5609 *sin_theta = sinf(theta) * mscale;
5610}
5611
5612static void ggml_rope_cache_init(
5613 float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5614 float * cache, float sin_sign, float theta_scale) {
5615 // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
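    // the cache holds one (cos, sin) pair per rotated dimension pair; sin is multiplied
    // by sin_sign so the backward pass can reuse this code for the inverse rotation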
5616 float theta = theta_base;
5617 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5618 const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5619 rope_yarn(
5620 theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5621 );
5622 cache[i0 + 1] *= sin_sign;
5623
5624 theta *= theta_scale;
5625 }
5626}
5627
5628static void ggml_mrope_cache_init(
5629 float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
5630 float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
5631 float * cache, float sin_sign, float theta_scale) {
5632 // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
5633 float theta_t = theta_base_t;
5634 float theta_h = theta_base_h;
5635 float theta_w = theta_base_w;
5636 float theta_e = theta_base_e; // extra position id for vision encoder
5637 int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
5638 int sec_w = sections[1] + sections[0];
5639 int sec_e = sections[2] + sec_w;
5640 GGML_ASSERT(sect_dims <= ne0);
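    // the rotated dimension pairs are split into [t | h | w | e] sections of sizes
    // sections[0..3]; the pattern repeats every sect_dims pairs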
5641
5642 for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
5643 const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
5644
5645 int sector = (i0 / 2) % sect_dims;
5646 if (indep_sects) {
            // compute theta independently for each dim section
            // (i.e. reset the corresponding theta when `i0` crosses from one section into another)
5649 if (sector == 0) {
5650 theta_t = theta_base_t;
5651 }
5652 else if (sector == sections[0]) {
                theta_h = theta_base_h;
5654 }
5655 else if (sector == sec_w) {
5656 theta_w = theta_base_w;
5657 }
5658 else if (sector == sec_e) {
5659 theta_e = theta_base_e;
5660 }
5661 }
5662
5663 float theta = theta_t;
        if (is_imrope) { // qwen3vl applies interleaved mrope
5665 if (sector % 3 == 1 && sector < 3 * sections[1]) {
5666 theta = theta_h;
5667 } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
5668 theta = theta_w;
5669 } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
5670 theta = theta_t;
5671 } else {
5672 theta = theta_e;
5673 }
5674 } else {
5675 if (sector >= sections[0] && sector < sec_w) {
5676 theta = theta_h;
5677 }
5678 else if (sector >= sec_w && sector < sec_w + sections[2]) {
5679 theta = theta_w;
5680 }
5681 else if (sector >= sec_w + sections[2]) {
5682 theta = theta_e;
5683 }
5684 }
5685
5686 rope_yarn(
5687 theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
5688 );
5689 cache[i0 + 1] *= sin_sign;
5690
5691 theta_t *= theta_scale;
5692 theta_w *= theta_scale;
5693 theta_h *= theta_scale;
5694 theta_e *= theta_scale;
5695 }
5696}
5697
5698
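// apply the 2D rotation [x0, x1] -> [x0*cos - x1*sin, x0*sin + x1*cos] to n/2 element
// pairs, where the second element of each pair is n_offset positions after the first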
5699template<typename T>
5700static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5701 for (int64_t i0 = 0; i0 < n; i0 += 2) {
5702 const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5703
5704 const float cos_theta = cache[i0 + 0];
5705 const float sin_theta = cache[i0 + 1];
5706
5707 const T * const src = src_data + ic;
5708 T * dst = dst_data + ic;
5709
5710 const float x0 = type_conversion_table<T>::to_f32(src[0]);
5711 const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5712
5713 dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5714 dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5715 }
5716}
5717
5718template<typename T> //float or ggml_fp16_t
5719static void ggml_compute_forward_rope_flt(
5720 const ggml_compute_params * params,
5721 ggml_tensor * dst,
5722 const bool forward) {
5723
5724 const ggml_tensor * src0 = dst->src[0];
5725 const ggml_tensor * src1 = dst->src[1];
5726 const ggml_tensor * src2 = dst->src[2];
5727
5728 GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
5729 GGML_ASSERT(src1->type == GGML_TYPE_I32);
5730
5731 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
5732 int sections[4];
5733
5734 //const int n_past = ((int32_t *) dst->op_params)[0];
5735 const int n_dims = ((int32_t *) dst->op_params)[1];
5736 const int mode = ((int32_t *) dst->op_params)[2];
5737 //const int n_ctx = ((int32_t *) dst->op_params)[3];
5738 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
5739
5740 memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
5741 memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
5742 memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
5743 memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
5744 memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
5745 memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
5747
5748 GGML_TENSOR_UNARY_OP_LOCALS
5749
5750 //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
5751 //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
5752
5753 GGML_ASSERT(nb0 == nb00);
5754 GGML_ASSERT(nb0 == sizeof(T));
5755
5756 const int ith = params->ith;
5757 const int nth = params->nth;
5758
5759 const int nr = ggml_nrows(dst);
5760
5761 GGML_ASSERT(n_dims <= ne0);
5762 GGML_ASSERT(n_dims % 2 == 0);
5763
5764 // rows per thread
5765 const int dr = (nr + nth - 1)/nth;
5766
5767 // row range for this thread
5768 const int ir0 = dr*ith;
5769 const int ir1 = MIN(ir0 + dr, nr);
5770
    // running row counter used to assign rows to this thread
5772 int ir = 0;
5773
5774 const float theta_scale = powf(freq_base, -2.0f/n_dims);
5775
5776 float corr_dims[2];
5777 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
5778
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
5780 const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
5781 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
5782
5783 if (mrope_used) {
5784 GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
5785 }
5786
5787 if (is_vision) {
5788 GGML_ASSERT(n_dims == ne0/2);
5789 }
5790
5791 const float * freq_factors = NULL;
5792 if (src2 != NULL) {
5793 GGML_ASSERT(src2->type == GGML_TYPE_F32);
5794 GGML_ASSERT(src2->ne[0] >= n_dims / 2);
5795 freq_factors = (const float *) src2->data;
5796 }
5797
5798 // backward process uses inverse rotation by cos and sin.
5799 // cos and sin build a rotation matrix, where the inverse is the transpose.
5800 // this essentially just switches the sign of sin.
5801 const float sin_sign = forward ? 1.0f : -1.0f;
5802
5803 const int32_t * pos = (const int32_t *) src1->data;
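    // for mrope, src1 carries four position streams of length ne2 (time, height, width
    // and an extra id used by the vision encoder), read as pos[i2 + k*ne2]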
5804
5805 for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
5806 for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
5807
5808 float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
5809 if (!mrope_used) {
5810 const int64_t p = pos[i2];
5811 ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5812 }
5813 else {
5814 const int64_t p_t = pos[i2];
5815 const int64_t p_h = pos[i2 + ne2];
5816 const int64_t p_w = pos[i2 + ne2 * 2];
5817 const int64_t p_e = pos[i2 + ne2 * 3];
5818 ggml_mrope_cache_init(
5819 p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
5820 freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
5821 }
5822
5823 for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
5824 if (ir++ < ir0) continue;
5825 if (ir > ir1) break;
5826
5827 T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
5828 T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
5829
5830 switch (mode) {
5831 case GGML_ROPE_TYPE_NORMAL:
5832 rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
5833 break;
5834 case GGML_ROPE_TYPE_NEOX:
5835 case GGML_ROPE_TYPE_MROPE:
5836 case GGML_ROPE_TYPE_IMROPE:
5837 rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
5838 break;
5839 case GGML_ROPE_TYPE_VISION:
5840 rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
5841 break;
5842 default:
5843 GGML_ABORT("rope type not supported");
5844 }
5845
5846 if (!is_vision) {
                // fill the remaining channels with data from the src tensor
5848 for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
5849 const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
5850 T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
5851
5852 dst_data[0] = src[0];
5853 dst_data[1] = src[1];
5854 }
5855 }
5856 } //attn-heads
5857 }
5858 }
5859}
5860
5861void ggml_compute_forward_rope(
5862 const ggml_compute_params * params,
5863 ggml_tensor * dst) {
5864
5865 const ggml_tensor * src0 = dst->src[0];
5866
5867 switch (src0->type) {
5868 case GGML_TYPE_F16:
5869 {
5870 ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
5871 } break;
5872 case GGML_TYPE_F32:
5873 {
5874 ggml_compute_forward_rope_flt<float>(params, dst, true);
5875 } break;
5876 default:
5877 {
5878 GGML_ABORT("fatal error");
5879 }
5880 }
5881}
5882
5883// ggml_compute_forward_rope_back
5884
5885void ggml_compute_forward_rope_back(
5886 const ggml_compute_params * params,
5887 ggml_tensor * dst) {
5888
5889 const ggml_tensor * src0 = dst->src[0];
5890
5891 switch (src0->type) {
5892 case GGML_TYPE_F16:
5893 {
5894 ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
5895 } break;
5896 case GGML_TYPE_F32:
5897 {
5898 ggml_compute_forward_rope_flt<float>(params, dst, false);
5899 } break;
5900 default:
5901 {
5902 GGML_ABORT("fatal error");
5903 }
5904 }
5905}
5906
5907// ggml_compute_forward_conv_transpose_1d
5908
5909static void ggml_compute_forward_conv_transpose_1d_f16_f32(
5910 const ggml_compute_params * params,
5911 ggml_tensor * dst) {
5912
5913 const ggml_tensor * src0 = dst->src[0];
5914 const ggml_tensor * src1 = dst->src[1];
5915
5916 GGML_ASSERT(src0->type == GGML_TYPE_F16);
5917 GGML_ASSERT(src1->type == GGML_TYPE_F32);
5918 GGML_ASSERT( dst->type == GGML_TYPE_F32);
5919
5920 GGML_TENSOR_BINARY_OP_LOCALS
5921
5922 const int ith = params->ith;
5923 const int nth = params->nth;
5924
5925 const int nk = ne00*ne01*ne02;
5926
5927 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
5928 GGML_ASSERT(nb10 == sizeof(float));
5929
5930 if (ith == 0) {
5931 memset(params->wdata, 0, params->wsize);
5932
5933 // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
5934 {
5935 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5936
5937 for (int64_t i02 = 0; i02 < ne02; i02++) {
5938 for (int64_t i01 = 0; i01 < ne01; i01++) {
5939 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
5940 ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
5941 for (int64_t i00 = 0; i00 < ne00; i00++) {
5942 dst_data[i00*ne02 + i02] = src[i00];
5943 }
5944 }
5945 }
5946 }
5947
5948 // permute source data (src1) from (L x Cin) to (Cin x L)
5949 {
5950 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
5951 ggml_fp16_t * dst_data = wdata;
5952
5953 for (int64_t i11 = 0; i11 < ne11; i11++) {
5954 const float * const src = (float *)((char *) src1->data + i11*nb11);
5955 for (int64_t i10 = 0; i10 < ne10; i10++) {
5956 dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
5957 }
5958 }
5959 }
5960
5961 // need to zero dst since we are accumulating into it
5962 memset(dst->data, 0, ggml_nbytes(dst));
5963 }
5964 ggml_barrier(params->threadpool);
5965
5966 const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
5967
5968 // total rows in dst
5969 const int nr = ne1;
5970
5971 // rows per thread
5972 const int dr = (nr + nth - 1)/nth;
5973
5974 // row range for this thread
5975 const int ir0 = dr*ith;
5976 const int ir1 = MIN(ir0 + dr, nr);
5977
5978 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
5979 ggml_fp16_t * const wdata_src = wdata + nk;
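// transposed conv as a scatter-add over input positions: for each output channel i1,
//   dst[i1][i10*s0 + i00] += dot over Cin of src[i10][:] with kernel[i1][i00][:]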
5980
5981 for (int i1 = ir0; i1 < ir1; i1++) {
5982 float * dst_data = (float *)((char *) dst->data + i1*nb1);
5983 ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
5984 for (int i10 = 0; i10 < ne10; i10++) {
5985 const int i1n = i10*ne11;
5986 for (int i00 = 0; i00 < ne00; i00++) {
5987 float v = 0;
5988 ggml_vec_dot_f16(ne02, &v, 0,
5989 (ggml_fp16_t *) wdata_src + i1n, 0,
5990 (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
5991 dst_data[i10*s0 + i00] += v;
5992 }
5993 }
5994 }
5995}
5996
5997static void ggml_compute_forward_conv_transpose_1d_f32(
5998 const ggml_compute_params * params,
5999 ggml_tensor * dst) {
6000
6001 const ggml_tensor * src0 = dst->src[0];
6002 const ggml_tensor * src1 = dst->src[1];
6003
6004 GGML_ASSERT(src0->type == GGML_TYPE_F32);
6005 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6006 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6007
6008 GGML_TENSOR_BINARY_OP_LOCALS
6009
6010 const int ith = params->ith;
6011 const int nth = params->nth;
6012
6013 const int nk = ne00*ne01*ne02;
6014
6015 GGML_ASSERT(nb00 == sizeof(float));
6016 GGML_ASSERT(nb10 == sizeof(float));
6017
6018 if (ith == 0) {
6019 memset(params->wdata, 0, params->wsize);
6020
6021 // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
6022 {
6023 float * const wdata = (float *) params->wdata + 0;
6024
6025 for (int64_t i02 = 0; i02 < ne02; i02++) {
6026 for (int64_t i01 = 0; i01 < ne01; i01++) {
6027 const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
6028 float * dst_data = wdata + i01*ne00*ne02;
6029 for (int64_t i00 = 0; i00 < ne00; i00++) {
6030 dst_data[i00*ne02 + i02] = src[i00];
6031 }
6032 }
6033 }
6034 }
6035
6036 // permute source data (src1) from (L x Cin) to (Cin x L)
6037 {
6038 float * const wdata = (float *) params->wdata + nk;
6039 float * dst_data = wdata;
6040
6041 for (int64_t i11 = 0; i11 < ne11; i11++) {
6042 const float * const src = (float *)((char *) src1->data + i11*nb11);
6043 for (int64_t i10 = 0; i10 < ne10; i10++) {
6044 dst_data[i10*ne11 + i11] = src[i10];
6045 }
6046 }
6047 }
6048
6049 // need to zero dst since we are accumulating into it
6050 memset(dst->data, 0, ggml_nbytes(dst));
6051 }
6052 ggml_barrier(params->threadpool);
6053
6054 const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
6055
6056 // total rows in dst
6057 const int nr = ne1;
6058
6059 // rows per thread
6060 const int dr = (nr + nth - 1)/nth;
6061
6062 // row range for this thread
6063 const int ir0 = dr*ith;
6064 const int ir1 = MIN(ir0 + dr, nr);
6065
6066 float * const wdata = (float *) params->wdata + 0;
6067 float * const wdata_src = wdata + nk;
6068
6069 for (int i1 = ir0; i1 < ir1; i1++) {
6070 float * dst_data = (float *)((char *) dst->data + i1*nb1);
6071 float * wdata_kernel = wdata + i1*ne02*ne00;
6072 for (int i10 = 0; i10 < ne10; i10++) {
6073 const int i1n = i10*ne11;
6074 for (int i00 = 0; i00 < ne00; i00++) {
6075 float v = 0;
6076 ggml_vec_dot_f32(ne02, &v, 0,
6077 wdata_src + i1n, 0,
6078 wdata_kernel + i00*ne02, 0, 1);
6079 dst_data[i10*s0 + i00] += v;
6080 }
6081 }
6082 }
6083}
6084
6085void ggml_compute_forward_conv_transpose_1d(
6086 const ggml_compute_params * params,
6087 ggml_tensor * dst) {
6088
6089 const ggml_tensor * src0 = dst->src[0];
6090
6091 switch (src0->type) {
6092 case GGML_TYPE_F16:
6093 {
6094 ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
6095 } break;
6096 case GGML_TYPE_F32:
6097 {
6098 ggml_compute_forward_conv_transpose_1d_f32(params, dst);
6099 } break;
6100 default:
6101 {
6102 GGML_ABORT("fatal error");
6103 }
6104 }
6105}
6106
6107// ggml_compute_forward_im2col_f32
6108// src0: kernel [OC, IC, KH, KW]
6109// src1: image [N, IC, IH, IW]
6110// dst: result [N, OH, OW, IC*KH*KW]
6111static void ggml_compute_forward_im2col_f32(
6112 const ggml_compute_params * params,
6113 ggml_tensor * dst) {
6114
6115 const ggml_tensor * src0 = dst->src[0];
6116 const ggml_tensor * src1 = dst->src[1];
6117
6118 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6119 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6120
6121 GGML_TENSOR_BINARY_OP_LOCALS;
6122
6123 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6124 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6125 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6126 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6127 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6128 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6129 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6130
6131 const int ith = params->ith;
6132 const int nth = params->nth;
6133
6134 const int64_t N = is_2D ? ne13 : ne12;
6135 const int64_t IC = is_2D ? ne12 : ne11;
6136 const int64_t IH = is_2D ? ne11 : 1;
6137 const int64_t IW = ne10;
6138
6139 const int64_t KH = is_2D ? ne01 : 1;
6140 const int64_t KW = ne00;
6141
6142 const int64_t OH = is_2D ? ne2 : 1;
6143 const int64_t OW = ne1;
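// the output extent follows the usual convolution relation:
//   OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1, and analogously for OH when is_2D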
6144
6145 int64_t ofs0 = is_2D ? nb13 : nb12;
6146 int64_t ofs1 = is_2D ? nb12 : nb11;
6147
6148 GGML_ASSERT(nb10 == sizeof(float));
6149
6150 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6151 {
6152 float * const wdata = (float *) dst->data;
6153
6154 for (int64_t in = 0; in < N; in++) {
6155 for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 if !is_2D
6156 for (int64_t iow = 0; iow < OW; iow++) {
6157 for (int64_t iic = ith; iic < IC; iic += nth) {
6158
6159 // micro kernel
6160 float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6161 const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6162
6163 for (int64_t ikh = 0; ikh < KH; ikh++) { // KH == 1 if !is_2D
6164 for (int64_t ikw = 0; ikw < KW; ikw++) {
6165 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6166 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6167
6168 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6169 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6170 } else {
6171 dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
6172 }
6173 }
6174 }
6175 }
6176 }
6177 }
6178 }
6179 }
6180}
6181
6182
6183// ggml_compute_forward_im2col_f16
6184// src0: kernel [OC, IC, KH, KW]
6185// src1: image [N, IC, IH, IW]
6186// dst: result [N, OH, OW, IC*KH*KW]
6187static void ggml_compute_forward_im2col_f16(
6188 const ggml_compute_params * params,
6189 ggml_tensor * dst) {
6190
6191 const ggml_tensor * src0 = dst->src[0];
6192 const ggml_tensor * src1 = dst->src[1];
6193
6194 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6195 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6196 GGML_ASSERT( dst->type == GGML_TYPE_F16);
6197
6198 GGML_TENSOR_BINARY_OP_LOCALS;
6199
6200 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6201 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6202 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6203 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6204 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6205 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6206 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6207
6208 const int ith = params->ith;
6209 const int nth = params->nth;
6210
6211 const int64_t N = is_2D ? ne13 : ne12;
6212 const int64_t IC = is_2D ? ne12 : ne11;
6213 const int64_t IH = is_2D ? ne11 : 1;
6214 const int64_t IW = ne10;
6215
6216 const int64_t KH = is_2D ? ne01 : 1;
6217 const int64_t KW = ne00;
6218
6219 const int64_t OH = is_2D ? ne2 : 1;
6220 const int64_t OW = ne1;
6221
6222 int64_t ofs0 = is_2D ? nb13 : nb12;
6223 int64_t ofs1 = is_2D ? nb12 : nb11;
6224
6225 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6226 GGML_ASSERT(nb10 == sizeof(float));
6227
6228 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6229 {
6230 ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6231
6232 for (int64_t in = 0; in < N; in++) {
6233 for (int64_t ioh = 0; ioh < OH; ioh++) { // OH == 1 if !is_2D
6234 for (int64_t iow = 0; iow < OW; iow++) {
6235 for (int64_t iic = ith; iic < IC; iic += nth) {
6236
6237 // micro kernel
6238 ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6239 const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
6240
6241 for (int64_t ikh = 0; ikh < KH; ikh++) { // KH == 1 if !is_2D
6242 for (int64_t ikw = 0; ikw < KW; ikw++) {
6243 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6244 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6245
6246 if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6247 dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
6248 } else {
6249 dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
6250 }
6251 }
6252 }
6253 }
6254 }
6255 }
6256 }
6257 }
6258}
6259
6260void ggml_compute_forward_im2col(
6261 const ggml_compute_params * params,
6262 ggml_tensor * dst) {
6263 switch (dst->type) {
6264 case GGML_TYPE_F16:
6265 {
6266 ggml_compute_forward_im2col_f16(params, dst);
6267 } break;
6268 case GGML_TYPE_F32:
6269 {
6270 ggml_compute_forward_im2col_f32(params, dst);
6271 } break;
6272 default:
6273 {
6274 GGML_ABORT("fatal error");
6275 }
6276 }
6277}
6278
6279// ggml_compute_forward_im2col_back_f32
6280
6281void ggml_compute_forward_im2col_back_f32(
6282 const ggml_compute_params * params,
6283 ggml_tensor * dst) {
6284
6285 const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
6286 const ggml_tensor * src1 = dst->src[1]; // convolution kernel
6287
6288 GGML_ASSERT(src0->type == GGML_TYPE_F32);
6289 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6290 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6291
6292 GGML_TENSOR_BINARY_OP_LOCALS;
6293
6294 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6295 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6296 const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
6297 const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
6298 const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
6299 const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
6300 const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
6301
6302 const int ith = params->ith;
6303 const int nth = params->nth;
6304
6305 const int64_t N = is_2D ? ne3 : ne2;
6306 const int64_t IC = is_2D ? ne2 : ne1;
6307 const int64_t IH = is_2D ? ne1 : 1;
6308 const int64_t IW = ne0;
6309
6310 const int64_t KH = is_2D ? ne11 : 1;
6311 const int64_t KW = ne10;
6312
6313 const int64_t OH = is_2D ? ne02 : 1;
6314 const int64_t OW = ne01;
6315
6316 int64_t ofs0 = is_2D ? nb3 : nb2;
6317 int64_t ofs1 = is_2D ? nb2 : nb1;
6318
6319 GGML_ASSERT(nb0 == sizeof(float));
6320
6321 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
6322 {
6323 float * const wdata = (float *) dst->data;
6324
6325 for (int64_t in = 0; in < N; in++) {
6326 for (int64_t iic = ith; iic < IC; iic += nth) {
6327 for (int64_t iih = 0; iih < IH; iih++) {
6328 for (int64_t iiw = 0; iiw < IW; iiw++) {
6329
6330 // micro kernel
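// invert the forward mapping iiw = iow*s0 + ikw*d0 - p0: for this input position, sum the
// output gradients of every (iow, ikw) / (ioh, ikh) combination that read it in the forward pass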
6331 float grad = 0.0f;
6332 for (int64_t ikh = 0; ikh < KH; ikh++) {
6333 for (int64_t ikw = 0; ikw < KW; ikw++) {
6334 // For s0 > 1 some values were skipped over in the forward pass.
6335 // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
6336 const int64_t tmpw = (iiw + p0 - ikw*d0);
6337 if (tmpw % s0 != 0) {
6338 continue;
6339 }
6340 const int64_t iow = tmpw / s0;
6341
6342 // Same logic as above, but for s1.
6343 int64_t ioh;
6344 if (is_2D) {
6345 const int64_t tmph = iih + p1 - ikh*d1;
6346
6347 if (tmph % s1 != 0) {
6348 continue;
6349 }
6350
6351 ioh = tmph / s1;
6352 } else {
6353 ioh = 0;
6354 }
6355
6356 if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
6357 continue;
6358 }
6359
6360 const float * const grad_in = (const float *) src0->data
6361 + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
6362 grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
6363 }
6364 }
6365 float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
6366 dst_data[iih*IW + iiw] = grad;
6367 }
6368 }
6369 }
6370 }
6371 }
6372}
6373
6374
6375// ggml_compute_forward_im2col_3d_f16
6376// src0: kernel [OC*IC, KD, KH, KW]
6377// src1: image [N*IC, ID, IH, IW]
6378// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6379static void ggml_compute_forward_im2col_3d_f16(
6380 const ggml_compute_params * params,
6381 ggml_tensor * dst) {
6382
6383 const ggml_tensor * src0 = dst->src[0];
6384 const ggml_tensor * src1 = dst->src[1];
6385
6386 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6387 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6388 GGML_ASSERT( dst->type == GGML_TYPE_F16);
6389
6390 GGML_TENSOR_BINARY_OP_LOCALS;
6391
6392 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6393 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6394 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6395 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6396 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6397 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6398 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6399 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6400 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6401 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6402
6403
6404 const int ith = params->ith;
6405 const int nth = params->nth;
6406
6407 const int64_t N = ne13 / IC;
6408 const int64_t ID = ne12;
6409 const int64_t IH = ne11;
6410 const int64_t IW = ne10;
6411
6412 const int64_t OC = ne03 / IC;
6413 GGML_UNUSED(OC);
6414 const int64_t KD = ne02;
6415 const int64_t KH = ne01;
6416 const int64_t KW = ne00;
6417
6418 const int64_t OD = ne3 / N;
6419 const int64_t OH = ne2;
6420 const int64_t OW = ne1;
6421 const int64_t OH_OW = OH*OW;
6422 const int64_t KD_KH_KW = KD*KH*KW;
6423 const int64_t KH_KW = KH*KW;
6424 const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6425
6426 GGML_ASSERT(nb10 == sizeof(float));
6427
6428 // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6429 {
6430 ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
6431
6432 for (int64_t in = 0; in < N; in++) {
6433 for (int64_t iod = 0; iod < OD; iod++) {
6434 for (int64_t ioh = 0; ioh < OH; ioh++) {
6435 for (int64_t iow = 0; iow < OW; iow++) {
6436 for (int64_t iic = ith; iic < IC; iic += nth) {
6437
6438 // micro kernel
6439 ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6440 const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6441
6442 for (int64_t ikd = 0; ikd < KD; ikd++) {
6443 for (int64_t ikh = 0; ikh < KH; ikh++) {
6444 for (int64_t ikw = 0; ikw < KW; ikw++) {
6445 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6446 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6447 const int64_t iid = iod*s2 + ikd*d2 - p2;
6448
6449 if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6450 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6451 } else {
6452 const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6453 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
6454 }
6455 }
6456 }
6457 }
6458 }
6459 }
6460 }
6461 }
6462 }
6463 }
6464}
6465
6466// ggml_compute_forward_im2col_3d_f32
6467// src0: kernel [OC*IC, KD, KH, KW]
6468// src1: image [N*IC, ID, IH, IW]
6469// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
6470static void ggml_compute_forward_im2col_3d_f32(
6471 const ggml_compute_params * params,
6472 ggml_tensor * dst) {
6473
6474 const ggml_tensor * src0 = dst->src[0];
6475 const ggml_tensor * src1 = dst->src[1];
6476
6477 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6478 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6479
6480 GGML_TENSOR_BINARY_OP_LOCALS;
6481
6482 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
6483 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
6484 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
6485 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
6486 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
6487 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
6488 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
6489 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
6490 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
6491 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
6492
6493
6494 const int ith = params->ith;
6495 const int nth = params->nth;
6496
6497 const int64_t N = ne13 / IC;
6498 const int64_t ID = ne12;
6499 const int64_t IH = ne11;
6500 const int64_t IW = ne10;
6501
6502 const int64_t OC = ne03 / IC;
6503 GGML_UNUSED(OC);
6504 const int64_t KD = ne02;
6505 const int64_t KH = ne01;
6506 const int64_t KW = ne00;
6507
6508 const int64_t OD = ne3 / N;
6509 const int64_t OH = ne2;
6510 const int64_t OW = ne1;
6511
6512 const int64_t OH_OW = OH*OW;
6513 const int64_t KD_KH_KW = KD*KH*KW;
6514 const int64_t KH_KW = KH*KW;
6515 const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
6516
6517 GGML_ASSERT(nb10 == sizeof(float));
6518
6519 // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
6520 {
6521 float * const wdata = (float *) dst->data;
6522
6523 for (int64_t in = 0; in < N; in++) {
6524 for (int64_t iod = 0; iod < OD; iod++) {
6525 for (int64_t ioh = 0; ioh < OH; ioh++) {
6526 for (int64_t iow = 0; iow < OW; iow++) {
6527 for (int64_t iic = ith; iic < IC; iic += nth) {
6528
6529 // micro kernel
6530 float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
6531 const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
6532
6533 for (int64_t ikd = 0; ikd < KD; ikd++) {
6534 for (int64_t ikh = 0; ikh < KH; ikh++) {
6535 for (int64_t ikw = 0; ikw < KW; ikw++) {
6536 const int64_t iiw = iow*s0 + ikw*d0 - p0;
6537 const int64_t iih = ioh*s1 + ikh*d1 - p1;
6538 const int64_t iid = iod*s2 + ikd*d2 - p2;
6539
6540 if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
6541 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
6542 } else {
6543 const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
6544 dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
6545 }
6546 }
6547 }
6548 }
6549 }
6550 }
6551 }
6552 }
6553 }
6554 }
6555}
6556
6557
6558void ggml_compute_forward_im2col_3d(
6559 const ggml_compute_params * params,
6560 ggml_tensor * dst) {
6561 switch (dst->type) {
6562 case GGML_TYPE_F16:
6563 {
6564 ggml_compute_forward_im2col_3d_f16(params, dst);
6565 } break;
6566 case GGML_TYPE_F32:
6567 {
6568 ggml_compute_forward_im2col_3d_f32(params, dst);
6569 } break;
6570 default:
6571 {
6572 GGML_ABORT("fatal error");
6573 }
6574 }
6575}
6576
6577static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
6578 void * a, void * b, float * c) {
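// computes c[m][n] = sum_k a[m][k] * b[n][k] (i.e. c = a * b^T, f32 output) by wrapping
// ggml_compute_forward_mul_mat around temporary tensor descriptors built on the stack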
6579 const ggml_type_traits * traits = ggml_get_type_traits(type);
6580 struct ggml_tensor src1 = {};
6581 src1.type = type;
6582 src1.ne[0] = k;
6583 src1.ne[1] = m;
6584 src1.ne[2] = 1;
6585 src1.ne[3] = 1;
6586 src1.nb[0] = traits->type_size;
6587 src1.nb[1] = k * traits->type_size;
6588 src1.nb[2] = src1.nb[1];
6589 src1.nb[3] = src1.nb[2];
6590 src1.data = a;
6591
6592 struct ggml_tensor src0 = {};
6593 src0.type = type;
6594 src0.ne[0] = k;
6595 src0.ne[1] = n;
6596 src0.ne[2] = 1;
6597 src0.ne[3] = 1;
6598 src0.nb[0] = traits->type_size;
6599 src0.nb[1] = k * traits->type_size;
6600 src0.nb[2] = src0.nb[1];
6601 src0.nb[3] = src0.nb[2];
6602 src0.data = b;
6603
6604 struct ggml_tensor dst = {};
6605 dst.ne[0] = n;
6606 dst.ne[1] = m;
6607 dst.ne[2] = 1;
6608 dst.ne[3] = 1;
6609 dst.nb[0] = sizeof(float);
6610 dst.nb[1] = n * sizeof(float);
6611 dst.nb[2] = dst.nb[1];
6612 dst.nb[3] = dst.nb[2];
6613 dst.data = c;
6614 dst.src[0] = &src0;
6615 dst.src[1] = &src1;
6616
6617 ggml_compute_forward_mul_mat(params, &dst);
6618}
6619
6620static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
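// e.g. ggml_wrap_around(-1, 5) == 4 and ggml_wrap_around(5, 5) == 0; assumes coord >= -size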
6621 return (coord + size) % size; // adding size avoids negative number weirdness
6622}
6623
6624// ggml_compute_forward_conv_2d
6625
6626
6627static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
6628 const ggml_tensor * kernel, // [KW, KH, IC, OC]
6629 const ggml_tensor * src, // [W, H, C, N]
6630 ggml_tensor * dst, // [OW, OH, OC, N]
6631 ggml_type kernel_type) {
6632
6633 GGML_ASSERT(ggml_is_contiguous(kernel));
6634 GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6635 GGML_ASSERT(kernel->type == kernel_type);
6636
6637 const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6638
6639 const int32_t stride_x = dst->op_params[0];
6640 const int32_t stride_y = dst->op_params[1];
6641 const int32_t pad_x = dst->op_params[2];
6642 const int32_t pad_y = dst->op_params[3];
6643 const int32_t dilation_x = dst->op_params[4];
6644 const int32_t dilation_y = dst->op_params[5];
6645
6646 const int64_t c_in = src->ne[2];
6647 const int64_t c_out = kernel->ne[3];
6648 GGML_ASSERT(c_in == kernel->ne[2]);
6649
6650 const int64_t src_w = src->ne[0];
6651 const int64_t src_h = src->ne[1];
6652 const int64_t knl_w = kernel->ne[0];
6653 const int64_t knl_h = kernel->ne[1];
6654 const int64_t dst_w = dst->ne[0];
6655 const int64_t dst_h = dst->ne[1];
6656
6657 const float * src_data = (float *) src->data;
6658 void * knl_data = kernel->data;
6659 float * dst_data = (float *) dst->data;
6660
6661 const int64_t knl_n = knl_w * knl_h * c_in;
6662 const int64_t patch_total = dst->ne[3] * dst_w * dst_h;
6663
6664 const int64_t space_per_patch = knl_n * traits->type_size + c_out * sizeof(float);
6665 const int64_t batch_size = params->wsize / space_per_patch;
6666 const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6667 const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
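// patches are processed in batches sized to fit the scratch buffer: each batch is im2col'ed
// into wdata and then multiplied against the whole kernel in a single GEMM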
6668
6669 GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6670
6671 void * tmp = params->wdata;
6672
6673 for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6674
6675 const int64_t patch_start_batch = batch_i * patches_per_batch;
6676 const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch,
6677 patch_total);
6678 const int64_t patch_n = patch_end_batch - patch_start_batch;
6679
6680 const int64_t patch_per_thread = (patch_n + params->nth - 1) / params->nth;
6681 const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6682 const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6683
6684 // im2col for a patch
6685 for (int64_t p = patch_start; p < patch_end; ++p) {
6686 const int64_t batch_n = p / (dst_w * dst_h);
6687 const int64_t src_x = (p / dst_w) % dst_h;
6688 const int64_t src_y = p % dst_w;
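// note: src_x is the output row index (dst y) and src_y the output column index (dst x);
// stride, dilation and padding are applied below to get the actual source coordinates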
6689
6690 const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
6691 char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
6692
6693 for (int64_t ic = 0; ic < c_in; ++ic) {
6694 for (int64_t ky = 0; ky < knl_h; ++ky) {
6695 for (int64_t kx = 0; kx < knl_w; ++kx) {
6696 const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
6697 const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
6698
6699 int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
6700
6701 float src_val;
6702 if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6703 src_val = 0.0f;
6704 } else {
6705 const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
6706 src_val = *src_ptr;
6707 }
6708
6709 char * element_ptr = dst_row + dst_idx * traits->type_size;
6710 if (kernel_type == GGML_TYPE_F32) {
6711 *(float *) element_ptr = src_val;
6712 } else if (kernel_type == GGML_TYPE_F16) {
6713 *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6714 }
6715 }
6716 }
6717 }
6718 } // patches handled by this thread
6719
6720 ggml_barrier(params->threadpool);
6721
6722 float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
6723
6724 GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize); // wsize is in bytes
6725
6726 // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
6727 ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
6728
6729 ggml_barrier(params->threadpool);
6730
6731
6732 // permute back [OC, N, OH, OW] to [N, OC, OH, OW]
6733 const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
6734 const int64_t permute_start = params->ith * permute_per_thread;
6735 const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
6736
6737 for (int64_t i = permute_start; i < permute_end; ++i) {
6738 const int64_t p = patch_start_batch + i;
6739 const int64_t batch_n = p / (dst_w * dst_h);
6740 const int64_t dst_y = (p / dst_w) % dst_h;
6741 const int64_t dst_x = p % dst_w;
6742
6743 for (int64_t oc = 0; oc < c_out; ++oc) {
6744 const float value = gemm_output[i * c_out + oc];
6745 float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
6746 *dst_ptr = value;
6747 }
6748 }
6749 }
6750}
6751
6752void ggml_compute_forward_conv_2d(
6753 const ggml_compute_params * params,
6754 ggml_tensor * dst) {
6755
6756 const ggml_tensor * src0 = dst->src[0];
6757 const ggml_tensor * src1 = dst->src[1];
6758
6759 ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
6760}
6761
6762// ggml_compute_forward_conv_3d
6763
6764static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
6765 const ggml_tensor * kernel,
6766 const ggml_tensor * src,
6767 ggml_tensor * dst,
6768 ggml_type kernel_type) {
6769
6770 GGML_ASSERT(ggml_is_contiguous(kernel));
6771 GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
6772 GGML_ASSERT(kernel->type == kernel_type);
6773
6774 const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
6775
6776 const int32_t s0 = dst->op_params[0];
6777 const int32_t s1 = dst->op_params[1];
6778 const int32_t s2 = dst->op_params[2];
6779 const int32_t p0 = dst->op_params[3];
6780 const int32_t p1 = dst->op_params[4];
6781 const int32_t p2 = dst->op_params[5];
6782 const int32_t d0 = dst->op_params[6];
6783 const int32_t d1 = dst->op_params[7];
6784 const int32_t d2 = dst->op_params[8];
6785 const int32_t c = dst->op_params[9];
6786 const int32_t n = dst->op_params[10];
6787 const int32_t oc = dst->op_params[11];
6788
6789 const int64_t src_w = src->ne[0];
6790 const int64_t src_h = src->ne[1];
6791 const int64_t src_d = src->ne[2];
6792 const int64_t knl_w = kernel->ne[0];
6793 const int64_t knl_h = kernel->ne[1];
6794 const int64_t knl_d = kernel->ne[2];
6795 const int64_t dst_w = dst->ne[0];
6796 const int64_t dst_h = dst->ne[1];
6797 const int64_t dst_d = dst->ne[2];
6798
6799 const float * src_data = (float *) src->data;
6800 void * knl_data = kernel->data;
6801 float * dst_data = (float *) dst->data;
6802
6803 const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
6804 const int64_t knl_n_total = knl_n_per_channel * c;
6805 const int64_t patch_total = n * dst_w * dst_h * dst_d;
6806
6807 const int64_t space_per_patch = knl_n_total * traits->type_size + oc * sizeof(float);
6808 const int64_t batch_size = params->wsize / space_per_patch;
6809 const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
6810 const int64_t batch_n = (patch_total + patches_per_batch - 1) / patches_per_batch;
6811
6812 GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
6813
6814 void * tmp = params->wdata;
6815
6816 for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
6817 const int64_t patch_start_batch = batch_i * patches_per_batch;
6818 const int64_t patch_end_batch = std::min(patch_start_batch + patches_per_batch, patch_total);
6819 const int64_t patch_n_in_batch = patch_end_batch - patch_start_batch;
6820
6821 const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6822 const int64_t patch_start = patch_start_batch + params->ith * patch_per_thread;
6823 const int64_t patch_end = std::min(patch_start + patch_per_thread, patch_end_batch);
6824
6825 for (int64_t p = patch_start; p < patch_end; ++p) {
6826 const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6827 const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6828 const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6829 const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6830 const int64_t dst_y = p_in_depth / dst_w;
6831 const int64_t dst_x = p_in_depth % dst_w;
6832
6833 char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
6834
6835 for (int64_t ic = 0; ic < c; ++ic) {
6836 for (int64_t kz = 0; kz < knl_d; ++kz) {
6837 for (int64_t ky = 0; ky < knl_h; ++ky) {
6838 for (int64_t kx = 0; kx < knl_w; ++kx) {
6839 const int64_t sz = dst_z * s2 + kz * d2 - p2;
6840 const int64_t sy = dst_y * s1 + ky * d1 - p1;
6841 const int64_t sx = dst_x * s0 + kx * d0 - p0;
6842
6843 int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
6844
6845 float src_val;
6846 if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
6847 src_val = 0.0f;
6848 } else {
6849 const int64_t cn_idx = batch_idx * c + ic;
6850 const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
6851 src_val = *src_ptr;
6852 }
6853
6854 char * element_ptr = dst_row + dst_idx * traits->type_size;
6855 if (kernel_type == GGML_TYPE_F32) {
6856 *(float *)element_ptr = src_val;
6857 } else if (kernel_type == GGML_TYPE_F16) {
6858 *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
6859 }
6860 }
6861 }
6862 }
6863 }
6864 }
6865
6866 ggml_barrier(params->threadpool);
6867
6868 float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
6869 ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
6870
6871 ggml_barrier(params->threadpool);
6872
6873 const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
6874 const int64_t permute_start = params->ith * permute_per_thread;
6875 const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
6876
6877 for (int64_t i = permute_start; i < permute_end; ++i) {
6878 const int64_t p = patch_start_batch + i;
6879 const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
6880 const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
6881 const int64_t batch_idx = p / (dst_w * dst_h * dst_d);
6882 const int64_t dst_z = p_in_batch / (dst_w * dst_h);
6883 const int64_t dst_y = p_in_depth / dst_w;
6884 const int64_t dst_x = p_in_depth % dst_w;
6885
6886 for (int64_t ioc = 0; ioc < oc; ++ioc) {
6887 const float value = gemm_output[i * oc + ioc];
6888 const int64_t ocn_idx = batch_idx * oc + ioc;
6889 float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
6890 *dst_ptr = value;
6891 }
6892 }
6893 }
6894}
6895
6896void ggml_compute_forward_conv_3d(
6897 const ggml_compute_params * params,
6898 ggml_tensor * dst) {
6899 const ggml_tensor * src0 = dst->src[0];
6900 const ggml_tensor * src1 = dst->src[1];
6901 ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
6902}
6903
6904// ggml_compute_forward_conv_transpose_2d
6905
6906void ggml_compute_forward_conv_transpose_2d(
6907 const ggml_compute_params * params,
6908 ggml_tensor * dst) {
6909
6910 const ggml_tensor * src0 = dst->src[0];
6911 const ggml_tensor * src1 = dst->src[1];
6912
6913 GGML_ASSERT(src0->type == GGML_TYPE_F16);
6914 GGML_ASSERT(src1->type == GGML_TYPE_F32);
6915 GGML_ASSERT( dst->type == GGML_TYPE_F32);
6916
6917 GGML_TENSOR_BINARY_OP_LOCALS
6918
6919 const int ith = params->ith;
6920 const int nth = params->nth;
6921
6922 const int nk = ne00*ne01*ne02*ne03;
6923
6924 GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
6925 GGML_ASSERT(nb10 == sizeof(float));
6926
6927 if (ith == 0) {
6928 memset(params->wdata, 0, params->wsize);
6929
6930 // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
6931 {
6932 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6933
6934 for (int64_t i03 = 0; i03 < ne03; i03++) {
6935 for (int64_t i02 = 0; i02 < ne02; i02++) {
6936 const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
6937 ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
6938 for (int64_t i01 = 0; i01 < ne01; i01++) {
6939 for (int64_t i00 = 0; i00 < ne00; i00++) {
6940 dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
6941 }
6942 }
6943 }
6944 }
6945 }
6946
6947 // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
6948 {
6949 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
6950 for (int i12 = 0; i12 < ne12; i12++) {
6951 for (int i11 = 0; i11 < ne11; i11++) {
6952 const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
6953 ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
6954 for (int i10 = 0; i10 < ne10; i10++) {
6955 dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
6956 }
6957 }
6958 }
6959 }
6960
6961 memset(dst->data, 0, ggml_nbytes(dst));
6962 }
6963 ggml_barrier(params->threadpool);
6964
6965 const int32_t stride = ggml_get_op_params_i32(dst, 0);
6966
6967 // total patches in dst
6968 const int np = ne2;
6969
6970 // patches per thread
6971 const int dp = (np + nth - 1)/nth;
6972
6973 // patch range for this thread
6974 const int ip0 = dp*ith;
6975 const int ip1 = MIN(ip0 + dp, np);
6976
6977 ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
6978 ggml_fp16_t * const wdata_src = wdata + nk;
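// 2D transposed conv as a scatter-add: for each output channel i2,
//   dst[i2][i11*stride + i01][i10*stride + i00] += dot over Cin of src[i11][i10][:] with kernel[i2][i01][i00][:]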
6979
6980 for (int i2 = ip0; i2 < ip1; i2++) { // Cout
6981 float * dst_data = (float *)((char *) dst->data + i2*nb2);
6982 ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
6983 for (int i11 = 0; i11 < ne11; i11++) {
6984 for (int i10 = 0; i10 < ne10; i10++) {
6985 const int i1n = i11*ne10*ne12 + i10*ne12;
6986 for (int i01 = 0; i01 < ne01; i01++) {
6987 for (int i00 = 0; i00 < ne00; i00++) {
6988 float v = 0;
6989 ggml_vec_dot_f16(ne03, &v, 0,
6990 wdata_src + i1n, 0,
6991 wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
6992 dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
6993 }
6994 }
6995 }
6996 }
6997 }
6998}
6999
7000// ggml_compute_forward_conv_2d_dw
7001
7002struct ggml_conv_2d_dw_params {
7003 int64_t channels;
7004 int64_t batch;
7005 int64_t src_w;
7006 int64_t src_h;
7007 int64_t dst_w;
7008 int64_t dst_h;
7009 int64_t knl_w;
7010 int64_t knl_h;
7011 int stride_x;
7012 int stride_y;
7013 int pad_x;
7014 int pad_y;
7015 int dilation_x;
7016 int dilation_y;
7017};
7018
7019static void ggml_compute_forward_conv_2d_dw_cwhn(
7020 const ggml_compute_params * params,
7021 const ggml_tensor * src,
7022 const ggml_tensor * kernel,
7023 ggml_tensor * dst,
7024 const ggml_conv_2d_dw_params & p) {
7025
7026 const int64_t c = p.channels;
7027 const float * knl_data = (const float *)kernel->data;
7028
7029 const int64_t rows_total = p.dst_h * p.batch;
7030 const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
7031 const int64_t row_start = params->ith * rows_per_thread;
7032 const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
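// in CWHN layout the channel dimension is innermost and contiguous, so each kernel tap can
// accumulate a whole run of consecutive channels with a single vector FMA in the SIMD path below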
7033
7034#ifdef GGML_SIMD
7035 #if defined(__ARM_FEATURE_SVE)
7036 const int64_t pkg_size = svcntw();
7037 #else
7038 const int64_t pkg_size = GGML_F32_EPR;
7039 #endif
7040 const int64_t pkg_count = c / pkg_size;
7041 const int64_t c_pkg_end = pkg_count * pkg_size;
7042#else
7043 const int64_t c_pkg_end = 0;
7044#endif
7045
7046 for (int64_t row = row_start; row < row_end; ++row) {
7047 const int64_t dst_y = row % p.dst_h;
7048 const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
7049 for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7050 float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
7051 const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
7052 const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
7053
7054#ifdef GGML_SIMD
7055 // Vectorized loop
7056 for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
7057 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
7058 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7059 const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7060 if (src_y < 0 || src_y >= p.src_h) {
7061 continue;
7062 }
7063 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7064 const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7065 if (src_x < 0 || src_x >= p.src_w) {
7066 continue;
7067 }
7068 GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
7069 GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
7070 sum = GGML_F32_VEC_FMA(sum, k, s);
7071 }
7072 }
7073 GGML_F32_VEC_STORE(dst_data + c_i, sum);
7074 }
7075#endif
7076 // Scalar loop
7077 for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
7078 float sum = 0.0f;
7079 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7080 const int64_t src_y = src_y_base + knl_y * p.dilation_y;
7081 if (src_y < 0 || src_y >= p.src_h) {
7082 continue;
7083 }
7084 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7085 const int64_t src_x = src_x_base + knl_x * p.dilation_x;
7086 if (src_x < 0 || src_x >= p.src_w) {
7087 continue;
7088 }
7089 sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
7090 * src_data[(src_y * p.src_w + src_x) * c + c_i];
7091 }
7092 }
7093 dst_data[c_i] = sum;
7094 }
7095 }
7096 }
7097}
7098
7099static void ggml_compute_forward_conv_2d_dw_whcn(
7100 const ggml_compute_params * params,
7101 const ggml_tensor * src,
7102 const ggml_tensor * kernel,
7103 ggml_tensor * dst,
7104 const ggml_conv_2d_dw_params & p) {
7105
7106 const int64_t n = p.channels * p.batch;
7107 const int64_t per_thread = (n + params->nth - 1) / params->nth;
7108 const int64_t start = params->ith * per_thread;
7109 const int64_t end = MIN(start + per_thread, n);
7110
7111 for (int64_t i = start; i < end; ++i) {
7112 const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
7113 const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
7114 float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
7115
7116 for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
7117 for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
7118
7119 float sum = 0.0f;
7120 for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
7121 const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
7122 if (src_y < 0 || src_y >= p.src_h) {
7123 continue;
7124 }
7125 for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
7126 const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
7127 if (src_x < 0 || src_x >= p.src_w) {
7128 continue;
7129 }
7130 sum += knl_data[knl_y * p.knl_w + knl_x]
7131 * src_data[src_y * p.src_w + src_x];
7132 }
7133 }
7134 dst_data[dst_y * p.dst_w + dst_x] = sum;
7135 }
7136 }
7137 }
7138}
7139
7140void ggml_compute_forward_conv_2d_dw(
7141 const ggml_compute_params * params,
7142 ggml_tensor * dst) {
7143
7144 const ggml_tensor * kernel = dst->src[0];
7145 const ggml_tensor * src = dst->src[1];
7146 ggml_conv_2d_dw_params p;
7147 p.channels = src->ne[2];
7148 p.batch = src->ne[3];
7149 p.src_w = src->ne[0];
7150 p.src_h = src->ne[1];
7151 p.dst_w = dst->ne[0];
7152 p.dst_h = dst->ne[1];
7153 p.knl_w = kernel->ne[0];
7154 p.knl_h = kernel->ne[1];
7155 p.stride_x = dst->op_params[0];
7156 p.stride_y = dst->op_params[1];
7157 p.pad_x = dst->op_params[2];
7158 p.pad_y = dst->op_params[3];
7159 p.dilation_x = dst->op_params[4];
7160 p.dilation_y = dst->op_params[5];
7161
7162 GGML_ASSERT(kernel->ne[3] == p.channels);
7163 GGML_ASSERT(dst->ne[3] == p.batch);
7164
7165 if (ggml_is_contiguous(src)) {
7166 ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
7167 } else if (ggml_is_contiguous_channels(src)) {
7168 // kernel should also have channels as its most contiguous dimension in memory
7169 GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
7170 ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
7171 } else {
7172 GGML_ABORT("non-contiguous memory layout not supported");
7173 }
7174}
7175
7176// ggml_compute_forward_pool_1d_ksp
7177static void ggml_compute_forward_pool_1d_ksp(
7178 const ggml_compute_params * params,
7179 const ggml_op_pool op,
7180 const int k,
7181 const int s,
7182 const int p,
7183 ggml_tensor * dst) {
7184
7185 const ggml_tensor * src = dst->src[0];
7186
7187 assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7188
7189 if (params->ith != 0) {
7190 return;
7191 }
7192
7193 const int64_t IW = src->ne[0];
7194 const int64_t OW = dst->ne[0];
7195
7196 const int64_t nr = ggml_nrows(src);
7197
7198 for (int64_t ir = 0; ir < nr; ++ir) {
7199 const char * srow_bytes = (const char *) src->data + ir * src->nb[1];
7200 float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]);
7201
7202 for (int64_t ow = 0; ow < OW; ++ow) {
7203 float res = 0;
7204 switch (op) {
7205 case GGML_OP_POOL_AVG: res = 0.0f; break;
7206 case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7207 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7208 }
7209
7210 int count = 0;
7211 const int base = (int) ow * s - p;
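// window start in the source row; out-of-bounds taps are skipped and, for AVG,
// excluded from the divisor (only in-bounds elements are counted)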
7212
7213 for (int ki = 0; ki < k; ++ki) {
7214 const int j = base + ki;
7215 if (j < 0 || j >= (int) IW) {
7216 continue;
7217 }
7218
7219 float v;
7220 if (src->type == GGML_TYPE_F32) {
7221 v = ((const float *) srow_bytes)[j];
7222 } else {
7223 v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
7224 }
7225
7226 switch (op) {
7227 case GGML_OP_POOL_AVG: res += v; break;
7228 case GGML_OP_POOL_MAX: res = std::max(v, res); break;
7229 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7230 }
7231
7232 ++count;
7233 }
7234
7235 switch (op) {
7236 case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
7237 case GGML_OP_POOL_MAX: break;
7238 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7239 }
7240
7241 drow[ow] = res;
7242 }
7243 }
7244}
7245
7246// ggml_compute_forward_pool_1d
7247
7248void ggml_compute_forward_pool_1d(
7249 const ggml_compute_params * params,
7250 ggml_tensor * dst) {
7251
7252 const int32_t * opts = (const int32_t *)dst->op_params;
7253 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7254 const int k0 = opts[1];
7255 const int s0 = opts[2];
7256 const int p0 = opts[3];
7257
7258 ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
7259}
7260
7261// ggml_compute_forward_pool_2d
7262
7263void ggml_compute_forward_pool_2d(
7264 const ggml_compute_params * params,
7265 ggml_tensor * dst) {
7266
7267 const ggml_tensor * src = dst->src[0];
7268
7269 assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
7270
7271 if (params->ith != 0) {
7272 return;
7273 }
7274
7275 const int32_t * opts = (const int32_t *)dst->op_params;
7276
7277 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7278 const int k0 = opts[1];
7279 const int k1 = opts[2];
7280 const int s0 = opts[3];
7281 const int s1 = opts[4];
7282 const int p0 = opts[5];
7283 const int p1 = opts[6];
7284 const char * cdata = (const char*)src->data;
7285 const char * const data_end = cdata + ggml_nbytes(src);
7286
7287 const int64_t px = dst->ne[0];
7288 const int64_t py = dst->ne[1];
7289 const int64_t pa = px * py;
7290
7291 float * dplane = (float *)dst->data;
7292
7293 const int ka = k0 * k1;
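// note: AVG pooling divides by the full window area ka, so padded (out-of-bounds) positions
// effectively count as zeros (count_include_pad semantics)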
7294 const int offset0 = -p0;
7295 const int offset1 = -p1;
7296
7297 while (cdata < data_end) {
7298 for (int oy = 0; oy < py; ++oy) {
7299 float * const drow = dplane + oy * px;
7300 float * const out = drow;
7301
7302 for (int ox = 0; ox < px; ++ox) {
7303 float res = 0;
7304 switch (op) {
7305 case GGML_OP_POOL_AVG: res = 0; break;
7306 case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
7307 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7308 }
7309
7310 const int ix = offset0 + ox * s0;
7311 const int iy = offset1 + oy * s1;
7312
7313 for (int ky = 0; ky < k1; ++ky) {
7314 if (iy + ky < 0 || iy + ky >= src->ne[1]) {
7315 continue;
7316 }
7317
7318 const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
7319 for (int kx = 0; kx < k0; ++kx) {
7320 int j = ix + kx;
7321 if (j < 0 || j >= src->ne[0]) {
7322 continue;
7323 }
7324
7325 const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
7326 switch (op) {
7327 case GGML_OP_POOL_AVG: res += srow_j; break;
7328 case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break;
7329 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7330 }
7331 }
7332 }
7333 switch (op) {
7334 case GGML_OP_POOL_AVG: res /= ka; break;
7335 case GGML_OP_POOL_MAX: break;
7336 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
7337 }
7338
7339 out[ox] = res;
7340 }
7341 }
7342
7343 cdata += src->nb[2];
7344 dplane += pa;
7345 }
7346}
7347
7348// ggml_compute_forward_pool_2d_back
7349
7350void ggml_compute_forward_pool_2d_back(
7351 const ggml_compute_params * params,
7352 ggml_tensor * dst) {
7353
7354 const ggml_tensor * src = dst->src[0];
7355 const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
7356
7357 assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
7358
7359 if (params->ith != 0) {
7360 return;
7361 }
7362
7363 const int32_t * opts = (const int32_t *)dst->op_params;
7364 ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
7365 const int k0 = opts[1];
7366 const int k1 = opts[2];
7367 const int s0 = opts[3];
7368 const int s1 = opts[4];
7369 const int p0 = opts[5];
7370 const int p1 = opts[6];
7371
7372 char * cdata = (char *) dst->data;
7373 const char * cdataf = (const char *) dstf->data;
7374 const char * const data_end = cdata + ggml_nbytes(dst);
7375
7376 GGML_ASSERT(params->ith == 0);
7377 memset(cdata, 0, ggml_nbytes(dst));
7378
7379 const int64_t px = src->ne[0];
7380 const int64_t py = src->ne[1];
7381 const int64_t pa = px * py;
7382
7383 const float * splane = (const float *) src->data;
7384
7385 const int ka = k0 * k1;
7386 const int offset0 = -p0;
7387 const int offset1 = -p1;
7388
7389 while (cdata < data_end) {
7390 for (int oy = 0; oy < py; ++oy) {
7391 const float * const srow = splane + oy * px;
7392 for (int ox = 0; ox < px; ++ox) {
7393 const float grad0 = srow[ox];
7394
7395 const int ix = offset0 + ox * s0;
7396 const int iy = offset1 + oy * s1;
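// MAX: route the gradient to the position that produced the forward maximum;
// AVG: spread the gradient evenly over the k0*k1 window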
7397
7398 if (op == GGML_OP_POOL_MAX) {
7399 float maxval = -FLT_MAX;
7400 int kxmax = -1;
7401 int kymax = -1;
7402
7403 for (int ky = 0; ky < k1; ++ky) {
7404 if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7405 continue;
7406 }
7407 const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
7408 for (int kx = 0; kx < k0; ++kx) {
7409 int j = ix + kx;
7410 if (j < 0 || j >= dst->ne[0]) {
7411 continue;
7412 }
7413
7414 const float val = dst->type == GGML_TYPE_F32 ?
7415 ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
7416 if (val <= maxval) {
7417 continue;
7418 }
7419
7420 maxval = val;
7421 kxmax = kx;
7422 kymax = ky;
7423 }
7424 }
7425
7426 if (kxmax == -1 || kymax == -1) {
7427 continue;
7428 }
7429
7430 void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
7431 const int j = ix + kxmax;
7432 if (dst->type == GGML_TYPE_F32) {
7433 ((float *) drow)[j] += grad0;
7434 } else {
7435 ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7436 }
7437 } else if (op == GGML_OP_POOL_AVG) {
7438 const float grad = grad0 / ka;
7439
7440 for (int ky = 0; ky < k1; ++ky) {
7441 if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
7442 continue;
7443 }
7444 void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
7445 for (int kx = 0; kx < k0; ++kx) {
7446 int j = ix + kx;
7447 if (j < 0 || j >= dst->ne[0]) {
7448 continue;
7449 }
7450
7451 if (dst->type == GGML_TYPE_F32) {
7452 ((float *) drow)[j] += grad;
7453 } else {
7454 ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
7455 }
7456 }
7457 }
7458 } else {
7459 GGML_ASSERT(false);
7460 }
7461 }
7462 }
7463
7464 cdata += dst->nb[2];
7465 cdataf += dst->nb[2];
7466 splane += pa;
7467 }
7468}
7469
7470// ggml_compute_forward_upscale
7471
7472static void ggml_compute_forward_upscale_f32(
7473 const ggml_compute_params * params,
7474 ggml_tensor * dst) {
7475
7476 const ggml_tensor * src0 = dst->src[0];
7477
7478 GGML_ASSERT(src0->type == GGML_TYPE_F32);
7479
7480 const int ith = params->ith;
7481 const int nth = params->nth;
7482
7483 GGML_TENSOR_UNARY_OP_LOCALS
7484
7485 float sf0 = (float)ne0/src0->ne[0];
7486 float sf1 = (float)ne1/src0->ne[1];
7487 float sf2 = (float)ne2/src0->ne[2];
7488 float sf3 = (float)ne3/src0->ne[3];
7489 float pixel_offset = 0.5f;
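// 0.5f selects half-pixel centers (align_corners == false); it is reset to 0 below when the
// align-corners flag is set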
7490
7491 const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
7492 const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
7493
7494 if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
7495 pixel_offset = 0.0f;
7496 sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
7497 sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
7498 }
7499
7500 if (mode == GGML_SCALE_MODE_NEAREST) {
7501 for (int64_t i3 = 0; i3 < ne3; i3++) {
7502 const int64_t i03 = i3 / sf3;
7503 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7504 const int64_t i02 = i2 / sf2;
7505 for (int64_t i1 = 0; i1 < ne1; i1++) {
7506 const int64_t i01 = i1 / sf1;
7507 for (int64_t i0 = 0; i0 < ne0; i0++) {
7508 const int64_t i00 = i0 / sf0;
7509
7510 const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7511 float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7512
7513 *y = *x;
7514 }
7515 }
7516 }
7517 }
7518 } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
7519 // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
7520 // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
7521 auto triangle_filter = [](float x) -> float {
7522 return std::max(1.0f - fabsf(x), 0.0f);
7523 };
7524
7525 // support and invscale, minimum 1 pixel for bilinear
7526 const float support1 = std::max(1.0f, 1.0f / sf1);
7527 const float invscale1 = 1.0f / support1;
7528 const float support0 = std::max(1.0f, 1.0f / sf0);
7529 const float invscale0 = 1.0f / support0;
7530
7531 for (int64_t i3 = 0; i3 < ne3; i3++) {
7532 const int64_t i03 = i3 / sf3;
7533 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7534 const int64_t i02 = i2 / sf2;
7535 for (int64_t i1 = 0; i1 < ne1; i1++) {
7536 const float y = ((float) i1 + pixel_offset) / sf1;
7537 for (int64_t i0 = 0; i0 < ne0; i0++) {
7538 const float x = ((float) i0 + pixel_offset) / sf0;
7539
7540 // the range of source pixels that contribute
7541 const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
7542 const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
7543 const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
7544 const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
7545
7546 // bilinear filter with antialiasing
7547 float val = 0.0f;
7548 float total_weight = 0.0f;
7549
7550 for (int64_t sy = y_min; sy < y_max; sy++) {
7551 const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
7552
7553 for (int64_t sx = x_min; sx < x_max; sx++) {
7554 const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
7555 const float weight = weight_x * weight_y;
7556
7557 if (weight <= 0.0f) {
7558 continue;
7559 }
7560
7561 const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
7562 val += pixel * weight;
7563 total_weight += weight;
7564 }
7565 }
7566
7567 if (total_weight > 0.0f) {
7568 val /= total_weight;
7569 }
7570
7571 float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7572 *dst_ptr = val;
7573 }
7574 }
7575 }
7576 }
7577 } else if (mode == GGML_SCALE_MODE_BILINEAR) {
7578 for (int64_t i3 = 0; i3 < ne3; i3++) {
7579 const int64_t i03 = i3 / sf3;
7580 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7581 const int64_t i02 = i2 / sf2;
7582 for (int64_t i1 = 0; i1 < ne1; i1++) {
7583 const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7584 int64_t y0 = (int64_t)floorf(y);
7585 int64_t y1 = y0 + 1;
7586
7587 y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
7588 y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
7589
7590 float dy = y - (float)y0;
7591 dy = std::max(0.0f, std::min(dy, 1.0f));
7592
7593 for (int64_t i0 = 0; i0 < ne0; i0++) {
7594 const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7595 int64_t x0 = (int64_t)floorf(x);
7596 int64_t x1 = x0 + 1;
7597
7598 x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
7599 x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
7600
7601 float dx = x - (float)x0;
7602 dx = std::max(0.0f, std::min(dx, 1.0f));
7603
7604 // fetch the four surrounding pixel values and interpolate
7605 const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7606 const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
7607 const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7608 const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
7609
7610 const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
7611
7612 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7613 *y_dst = val;
7614 }
7615 }
7616 }
7617 }
7618 } else if (mode == GGML_SCALE_MODE_BICUBIC) {
7619 // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
7620 const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
7621 auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
7622 auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
7623 auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
7624 const float w0 = weight2(x + 1);
7625 const float w1 = weight1(x + 0);
7626 const float w2 = weight1(1 - x);
7627 const float w3 = weight2(2 - x);
7628 return p0*w0 + p1*w1 + p2*w2 + p3*w3;
7629 };
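// w0..w3 are the Keys cubic-convolution weights of the four neighbouring samples; they sum to 1
// for any x in [0, 1], so constant inputs are reproduced exactly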
7630
7631 for (int64_t i3 = 0; i3 < ne3; i3++) {
7632 const int64_t i03 = i3 / sf3;
7633 for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
7634 const int64_t i02 = i2 / sf2;
7635 for (int64_t i1 = 0; i1 < ne1; i1++) {
7636 const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
7637 const int64_t y0 = (int64_t)floorf(y);
7638 const float dy = y - (float)y0;
7639
7640 for (int64_t i0 = 0; i0 < ne0; i0++) {
7641 const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
7642 const int64_t x0 = (int64_t)floorf(x);
7643 const float dx = x - (float)x0;
7644
7645 auto p = [=](int64_t x_off, int64_t y_off) -> float {
7646 int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
7647 int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
7648 return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
7649 };
7650
7651 const float val = bicubic(
7652 bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
7653 bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
7654 bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
7655 bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
7656
7657 float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
7658 *y_dst = val;
7659 }
7660 }
7661 }
7662 }
7663 } else {
7664 GGML_ABORT("unsupported upscale mode");
7665 }
7666}
7667
7668void ggml_compute_forward_upscale(
7669 const ggml_compute_params * params,
7670 ggml_tensor * dst) {
7671
7672 const ggml_tensor * src0 = dst->src[0];
7673
7674 switch (src0->type) {
7675 case GGML_TYPE_F32:
7676 {
7677 ggml_compute_forward_upscale_f32(params, dst);
7678 } break;
7679 default:
7680 {
7681 GGML_ABORT("fatal error");
7682 }
7683 }
7684}
7685
7686
7687// ggml_compute_forward_pad
7688
7689template<bool circular_t>
7690static void ggml_compute_forward_pad_f32(
7691 const ggml_compute_params * params,
7692 ggml_tensor * dst) {
7693
7694 const ggml_tensor * src0 = dst->src[0];
7695
7696 assert(dst->nb[0] == sizeof(float));
7697
7698 const int ith = params->ith;
7699 const int nth = params->nth;
7700
7701 GGML_TENSOR_UNARY_OP_LOCALS
7702
7703 float * dst_ptr = (float *) dst->data;
7704 const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
7705 const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
7706 const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
7707 const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
7708 const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
7709 const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
7710 const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
7711 const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
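    // sketch of the expected shapes: per dimension, ne(dst) = ne(src0) + lp + rp;
    // constant padding zero-fills everything outside the source region, while circular
    // padding wraps the source indices around instead (see the two branches below)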
7712
7713 // TODO: optimize
7714
7715 for (int64_t i2 = 0; i2 < ne2; ++i2) {
7716 for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7717 for (int64_t i0 = 0; i0 < ne0; ++i0) {
7718 for (int64_t i3 = 0; i3 < ne3; ++i3) {
7719                     // circular padding wraps around on a torus: out-of-range indices in every padded dim loop back into the source

7720 if constexpr (circular_t) {
7721 const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7722 const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
7723 const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
7724 const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
7725 const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
7726
7727 const int64_t src_idx =
7728 src_i3*nb03 +
7729 src_i2*nb02 +
7730 src_i1*nb01 +
7731 src_i0*nb00;
7732
7733 const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7734 dst_ptr[dst_idx] = *src_ptr;
7735 } else {
7736 const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
7737                         if ((i0 >= lp0 && i0 < ne0 - rp0)
7738                             && (i1 >= lp1 && i1 < ne1 - rp1)
7739                             && (i2 >= lp2 && i2 < ne2 - rp2)
7740                             && (i3 >= lp3 && i3 < ne3 - rp3)) {
7741 const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
7742 const float * src_ptr = (const float *)((char *) src0->data + src_idx);
7743 dst_ptr[dst_idx] = *src_ptr;
7744 } else {
7745 dst_ptr[dst_idx] = 0;
7746 }
7747 }
7748 }
7749 }
7750 }
7751 }
7752}
7753
7754
7755void ggml_compute_forward_pad(
7756 const ggml_compute_params * params,
7757 ggml_tensor * dst) {
7758 const ggml_tensor * src0 = dst->src[0];
7759 const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
7760 switch (src0->type) {
7761 case GGML_TYPE_F32:
7762 {
7763 if (circular) {
7764 ggml_compute_forward_pad_f32<true>(params, dst);
7765 } else {
7766 ggml_compute_forward_pad_f32<false>(params, dst);
7767 }
7768 } break;
7769 default:
7770 {
7771 GGML_ABORT("fatal error");
7772 }
7773 }
7774}
7775
7776// ggml_compute_forward_pad_reflect_1d
7777
7778void ggml_compute_forward_pad_reflect_1d(
7779 const ggml_compute_params * params,
7780 ggml_tensor * dst) {
7781
7782 const ggml_tensor * src0 = dst->src[0];
7783
7784 GGML_ASSERT(src0->type == GGML_TYPE_F32);
7785 GGML_ASSERT( dst->type == GGML_TYPE_F32);
7786
7787 const int ith = params->ith;
7788 const int nth = params->nth;
7789
7790 const int32_t * opts = (const int32_t *) dst->op_params;
7791 const int p0 = opts[0];
7792 const int p1 = opts[1];
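    // reflect padding mirrors the row around its first/last element without repeating it,
    // e.g. src = [a b c d] with p0 = 2, p1 = 1 -> dst = [c b a b c d c]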
7793
7794 GGML_TENSOR_UNARY_OP_LOCALS
7795
7796 for (int64_t i3 = 0; i3 < ne3; i3++) {
7797 for (int64_t i2 = 0; i2 < ne2; i2++) {
7798 for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
7799 float * left = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + p0*nb0);
7800 float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
7801
7802 ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
7803
7804 for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0]; }
7805 for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
7806 }
7807 }
7808 }
7809}
7810
7811// ggml_compute_forward_roll
7812
7813static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
7814 if (i < 0) {
7815 return i + ne;
7816 } else if (i >= ne) {
7817 return i - ne;
7818 }
7819 return i;
7820}
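// e.g. ggml_wrap_index(-3, 8) == 5 and ggml_wrap_index(10, 8) == 2; note the index is
// wrapped at most once, so callers are assumed to pass -ne <= i < 2*ne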
7821
7822static void ggml_compute_forward_roll_f32(
7823 const ggml_compute_params * params,
7824 ggml_tensor * dst) {
7825
7826 const ggml_tensor * src0 = dst->src[0];
7827 const float * src_data = (const float *) src0->data;
7828 float * dst_data = (float *) dst->data;
7829
7830 GGML_TENSOR_UNARY_OP_LOCALS
7831
7832 const int s0 = ggml_get_op_params_i32(dst, 0);
7833 const int s1 = ggml_get_op_params_i32(dst, 1);
7834 const int s2 = ggml_get_op_params_i32(dst, 2);
7835 const int s3 = ggml_get_op_params_i32(dst, 3);
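    // roll semantics (sketch): dst[..., i, ...] = src[..., wrap(i - s), ...] along every dim,
    // e.g. the row {1, 2, 3, 4} shifted by s0 = 1 becomes {4, 1, 2, 3}; the innermost dim is
    // handled with two contiguous copies below instead of per-element wrapping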
7836
7837 const int64_t total = ne1 * ne2 * ne3;
7838 const int64_t per_thread = (total + params->nth) / params->nth;
7839 const int64_t start = params->ith * per_thread;
7840 const int64_t end = std::min(start + per_thread, total);
7841
7842 for (int64_t i = start; i < end; ++i) {
7843 const int64_t i1 = i % ne1;
7844 const int64_t i2 = (i / ne1) % ne2;
7845 const int64_t i3 = i / (ne2 * ne1);
7846 float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
7847
7848 const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
7849 const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
7850 const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
7851 const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
7852
7853 const int64_t s = ggml_wrap_index(-s0, ne00);
7854 const int64_t n = ne00 - s;
7855 ggml_vec_cpy_f32(n, dst_row, src_row + s);
7856 ggml_vec_cpy_f32(s, dst_row + n, src_row);
7857 }
7858}
7859
7860void ggml_compute_forward_roll(
7861 const ggml_compute_params * params,
7862 ggml_tensor * dst) {
7863
7864 const ggml_tensor * src0 = dst->src[0];
7865
7866 switch (src0->type) {
7867 case GGML_TYPE_F32:
7868 {
7869 ggml_compute_forward_roll_f32(params, dst);
7870 } break;
7871 default:
7872 {
7873 GGML_ABORT("fatal error");
7874 }
7875 }
7876}
7877
7878// ggml_compute_forward_arange
7879
7880static void ggml_compute_forward_arange_f32(
7881 const ggml_compute_params * params,
7882 ggml_tensor * dst) {
7883
7884 GGML_ASSERT(dst->nb[0] == sizeof(float));
7885
7886 const int ith = params->ith;
7887 const int nth = params->nth;
7888
7889 const float start = ggml_get_op_params_f32(dst, 0);
7890 const float stop = ggml_get_op_params_f32(dst, 1);
7891 const float step = ggml_get_op_params_f32(dst, 2);
7892
7893 const int64_t steps = (int64_t) ceilf((stop - start) / step);
7894
7895 GGML_ASSERT(ggml_nelements(dst) == steps);
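    // e.g. start = 0, stop = 5, step = 2 -> steps = ceil(2.5) = 3, producing {0, 2, 4}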
7896
7897 for (int64_t i = ith; i < steps; i+= nth) {
7898 float value = start + step * i;
7899 ((float *)dst->data)[i] = value;
7900 }
7901}
7902
7903void ggml_compute_forward_arange(
7904 const ggml_compute_params * params,
7905 ggml_tensor * dst) {
7906 switch (dst->type) {
7907 case GGML_TYPE_F32:
7908 {
7909 ggml_compute_forward_arange_f32(params, dst);
7910 } break;
7911 default:
7912 {
7913 GGML_ABORT("fatal error");
7914 }
7915 }
7916}
7917
7918static void ggml_compute_forward_timestep_embedding_f32(
7919 const ggml_compute_params * params,
7920 ggml_tensor * dst) {
7921
7922 const ggml_tensor * src0 = dst->src[0];
7923
7924 GGML_ASSERT(src0->nb[0] == sizeof(float));
7925
7926 const int ith = params->ith;
7927 const int nth = params->nth;
7928
7929 GGML_TENSOR_UNARY_OP_LOCALS
7930
7931 const int dim = ggml_get_op_params_i32(dst, 0);
7932 const int max_period = ggml_get_op_params_i32(dst, 1);
7933
7934 int half = dim / 2;
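    // sketch of the embedding computed below: freq_j = max_period^(-j/half), and for each
    // timestep t the output row is [cos(t*freq_0) .. cos(t*freq_{half-1}), sin(t*freq_0) ..
    // sin(t*freq_{half-1})], with one trailing zero when dim is odd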
7935
7936 for (int64_t i = 0; i < ne00; i++) {
7937 float * embed_data = (float *)((char *) dst->data + i*nb1);
7938 for (int64_t j = ith; j < half; j += nth) {
7939 float timestep = ((float *)src0->data)[i];
7940 float freq = (float)expf(-logf(max_period) * j / half);
7941 float arg = timestep * freq;
7942 embed_data[j] = cosf(arg);
7943 embed_data[j + half] = sinf(arg);
7944 }
7945 if (dim % 2 != 0 && ith == 0) {
7946 embed_data[2 * half] = 0.f;
7947 }
7948 }
7949}
7950
7951void ggml_compute_forward_timestep_embedding(
7952 const ggml_compute_params * params,
7953 ggml_tensor * dst) {
7954
7955 const ggml_tensor * src0 = dst->src[0];
7956
7957 switch (src0->type) {
7958 case GGML_TYPE_F32:
7959 {
7960 ggml_compute_forward_timestep_embedding_f32(params, dst);
7961 } break;
7962 default:
7963 {
7964 GGML_ABORT("fatal error");
7965 }
7966 }
7967}
7968
7969// ggml_compute_forward_argsort
7970
7971template<enum ggml_sort_order order>
7972struct cmp_argsort {
7973 const float * data;
7974 bool operator()(int32_t a, int32_t b) const {
7975 if constexpr (order == GGML_SORT_ORDER_ASC) {
7976 return data[a] < data[b];
7977 } else {
7978 return data[a] > data[b];
7979 }
7980 }
7981};
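// e.g. for data = {0.3f, 0.1f, 0.5f}: GGML_SORT_ORDER_ASC sorts the index row {0, 1, 2}
// into {1, 0, 2}, GGML_SORT_ORDER_DESC into {2, 0, 1}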
7982
7983static void ggml_compute_forward_argsort_f32(
7984 const ggml_compute_params * params,
7985 ggml_tensor * dst) {
7986
7987 const ggml_tensor * src0 = dst->src[0];
7988
7989 GGML_TENSOR_UNARY_OP_LOCALS
7990
7991 GGML_ASSERT(nb0 == sizeof(float));
7992
7993 const int ith = params->ith;
7994 const int nth = params->nth;
7995
7996 const int64_t nr = ggml_nrows(src0);
7997
7998 ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
7999
8000 for (int64_t i = ith; i < nr; i += nth) {
8001 const float * src_data = (float *)((char *) src0->data + i*nb01);
8002
8003 int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8004
8005 for (int64_t j = 0; j < ne0; j++) {
8006 dst_data[j] = j;
8007 }
8008
8009 switch (order) {
8010 case GGML_SORT_ORDER_ASC:
8011 std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
8012 break;
8013
8014 case GGML_SORT_ORDER_DESC:
8015 std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
8016 break;
8017
8018 default:
8019 GGML_ABORT("invalid sort order");
8020 }
8021 }
8022}
8023
8024void ggml_compute_forward_argsort(
8025 const ggml_compute_params * params,
8026 ggml_tensor * dst) {
8027
8028 const ggml_tensor * src0 = dst->src[0];
8029
8030 switch (src0->type) {
8031 case GGML_TYPE_F32:
8032 {
8033 ggml_compute_forward_argsort_f32(params, dst);
8034 } break;
8035 default:
8036 {
8037 GGML_ABORT("fatal error");
8038 }
8039 }
8040}
8041
8042// ggml_compute_forward_top_k
8043
8044struct cmp_top_k {
8045 const float * data;
8046 bool operator()(int32_t a, int32_t b) const {
8047 return data[a] > data[b];
8048 }
8049};
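// the comparator orders indices by decreasing value, so std::partial_sort in the kernel
// below leaves the indices of the top_k largest values in the first top_k slots;
// e.g. data = {0.2f, 0.9f, 0.5f, 0.1f} with top_k = 2 yields {1, 2} (before the final swap
// in ggml_compute_forward_top_k_f32, which deliberately scrambles the order)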
8050
8051static void ggml_compute_forward_top_k_f32(
8052 const ggml_compute_params * params,
8053 ggml_tensor * dst) {
8054
8055 const ggml_tensor * src0 = dst->src[0];
8056
8057 GGML_TENSOR_UNARY_OP_LOCALS
8058
8059 GGML_ASSERT(nb0 == sizeof(float));
8060
8061 const int ith = params->ith;
8062 const int nth = params->nth;
8063
8064 const int64_t nr = ggml_nrows(src0);
8065
8066 const int top_k = ne0;
8067
8068 int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
8069
8070 for (int64_t i = ith; i < nr; i += nth) {
8071 const float * src_data = (float *)((char *) src0->data + i*nb01);
8072
8073 for (int64_t j = 0; j < ne00; j++) {
8074 tmp[j] = j;
8075 }
8076
8077 std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
8078
8079 int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
8080
8081 std::copy(tmp, tmp + top_k, dst_data);
8082
8083         // the order of the returned indices is not guaranteed - swap two entries so callers do not rely on it
8084 if (top_k > 1) {
8085 std::swap(dst_data[0], dst_data[1]);
8086 }
8087 }
8088}
8089
8090void ggml_compute_forward_top_k(
8091 const ggml_compute_params * params,
8092 ggml_tensor * dst) {
8093
8094 const ggml_tensor * src0 = dst->src[0];
8095
8096 switch (src0->type) {
8097 case GGML_TYPE_F32:
8098 {
8099 ggml_compute_forward_top_k_f32(params, dst);
8100 } break;
8101 default:
8102 {
8103 GGML_ABORT("fatal error");
8104 }
8105 }
8106}
8107
8108static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8109 const ggml_compute_params * params,
8110 ggml_tensor * dst,
8111 int ir0, int ir1,
8112 int64_t ic_start, int64_t ic_end,
8113 float * partials, int64_t partial_stride) {
8114
8115 const bool write_partials = (partials != nullptr);
8116 const ggml_tensor * q = dst->src[0];
8117 const ggml_tensor * k = dst->src[1];
8118 const ggml_tensor * v = dst->src[2];
8119 const ggml_tensor * mask = dst->src[3];
8120 const ggml_tensor * sinks = dst->src[4];
8121
8122 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8123 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8124 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8125 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8126 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8127 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8128 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8129 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8130
8131 const int64_t DK = nek0;
8132 const int64_t DV = nev0;
8133 const int64_t N = neq1;
8134
8135 GGML_ASSERT(ne0 == DV);
8136 GGML_ASSERT(ne2 == N);
8137
8138 // input tensor rows must be contiguous
8139 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8140 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8141 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8142
8143 GGML_ASSERT(neq0 == DK);
8144 GGML_ASSERT(nek0 == DK);
8145 GGML_ASSERT(nev0 == DV);
8146
8147 GGML_ASSERT(neq1 == N);
8148
8149 // dst cannot be transposed or permuted
8150 GGML_ASSERT(nb0 == sizeof(float));
8151 GGML_ASSERT(nb0 <= nb1);
8152 GGML_ASSERT(nb1 <= nb2);
8153 GGML_ASSERT(nb2 <= nb3);
8154
8155 // broadcast factors
8156 const int64_t rk2 = neq2/nek2;
8157 const int64_t rk3 = neq3/nek3;
8158
8159 const int64_t rv2 = neq2/nev2;
8160 const int64_t rv3 = neq3/nev3;
8161
8162     // parallelize by q rows; the KQ dot products use the vec_dot of the K type
8163
8164 float scale = 1.0f;
8165 float max_bias = 0.0f;
8166 float logit_softcap = 0.0f;
8167
8168 memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
8169 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
8170 memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8171
8172 if (logit_softcap != 0) {
8173 scale /= logit_softcap;
8174 }
8175
8176 const uint32_t n_head = neq2;
8177 const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8178
8179 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8180 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
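    // these follow the ALiBi-style slope schedule: head h gets slope m0^(h+1) when
    // h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise (see `slope` below),
    // i.e. roughly 2^(-max_bias*(h+1)/n_head_log2) with interpolated slopes for heads
    // beyond the nearest power of two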
8181
8182 ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
8183 ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
8184 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
8185 ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
8186
8187 GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
8188 GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
8189
8190 int ith = params->ith;
8191
8192 for (int ir = ir0; ir < ir1; ++ir) {
8193 // q indices
8194 const int iq3 = ir/(neq2*neq1);
8195 const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8196 const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8197
8198 const uint32_t h = iq2; // head index
8199 const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8200
8201 float S = 0.0f; // sum
8202 float M = -INFINITY; // maximum KQ value
8203
8204 float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
8205 float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer
8206 ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
8207 ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
8208
8209 if (v->type == GGML_TYPE_F16) {
8210 memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
8211 } else {
8212 memset(VKQ32, 0, DV*sizeof(float));
8213 }
8214
8215 const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
8216
8217 // k indices
8218 const int ik3 = iq3 / rk3;
8219 const int ik2 = iq2 / rk2;
8220
8221 // v indices
8222 const int iv3 = iq3 / rv3;
8223 const int iv2 = iq2 / rv2;
8224
8225 const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
8226 q_to_vec_dot(pq, Q_q, DK);
8227
8228 // online softmax / attention
8229 // loop over n_kv and n_head_kv
8230 // ref: https://arxiv.org/pdf/2112.05682.pdf
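        //
        // sketch of the streaming update maintained below, for each new score s with value row v:
        //   M'   = max(M, s)
        //   VKQ' = VKQ*expf(M - M') + v*expf(s - M')
        //   S'   = S  *expf(M - M') +   expf(s - M')
        // so that VKQ/S is the softmax-weighted sum of the V rows processed so far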
8231
8232 for (int64_t ic = ic_start; ic < ic_end; ++ic) {
8233 const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
8234 if (mv == -INFINITY) {
8235 continue;
8236 }
8237
8238 float s; // KQ value
8239
8240 const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
8241 kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
8242
8243 s = s*scale; // scale KQ value
8244
8245 if (logit_softcap != 0.0f) {
8246 s = logit_softcap*tanhf(s);
8247 }
8248
8249 s += mv; // apply mask
8250
8251 const float Mold = M;
8252
8253 float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
8254 float vs = 1.0f; // post-softmax KQ value, expf(s - M)
8255
8256 const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
8257
8258 if (v->type == GGML_TYPE_F16) {
8259 if (s > M) {
8260 // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8261 M = s;
8262 ms = expf(Mold - M);
8263
8264 // V = V*expf(Mold - M)
8265 ggml_vec_scale_f16(DV, VKQ16, ms);
8266 } else {
8267 // no new maximum, ms == 1.0f, vs != 1.0f
8268 vs = expf(s - M);
8269 }
8270
8271 // V += v*expf(s - M)
8272 ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
8273 } else {
8274 if (s > M) {
8275 // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
8276 M = s;
8277 ms = expf(Mold - M);
8278
8279 // V = V*expf(Mold - M)
8280 ggml_vec_scale_f32(DV, VKQ32, ms);
8281 } else {
8282 // no new maximum, ms == 1.0f, vs != 1.0f
8283 vs = expf(s - M);
8284 }
8285
8286 // V += v*expf(s - M)
8287 if (v_to_float) {
8288 v_to_float(v_data, V32, DV);
8289 ggml_vec_mad_f32(DV, VKQ32, V32, vs);
8290 } else {
8291 // V is F32
8292 ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
8293 }
8294 }
8295
8296 S = S*ms + vs; // scale and increment sum with partial sum
8297 }
8298
8299 if (v->type == GGML_TYPE_F16) {
8300 for (int64_t d = 0; d < DV; ++d) {
8301 VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
8302 }
8303 }
8304
8305 // sinks - apply only on the first kv-chunk
8306 if (sinks && ic_start == 0) {
8307 const float s = ((float *)((char *) sinks->data))[h];
8308
8309 float ms = 1.0f;
8310 float vs = 1.0f;
8311
8312 if (s > M) {
8313 ms = expf(M - s);
8314 M = s;
8315 ggml_vec_scale_f32(DV, VKQ32, ms);
8316 } else {
8317 vs = expf(s - M);
8318 }
8319
8320 S = S*ms + vs;
8321 }
8322
8323 if (write_partials) {
8324 // Write M, S, VKQ to partials for later reduction
8325 // partials layout: [M, S, VKQ[DV]] per query head
8326 float * partial = partials + ir * partial_stride;
8327 partial[0] = M;
8328 partial[1] = S;
8329 memcpy(partial + 2, VKQ32, DV * sizeof(float));
8330 } else {
8331 // V /= S
8332 const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
8333 ggml_vec_scale_f32(DV, VKQ32, S_inv);
8334
8335 // dst indices
8336 const int i1 = iq1;
8337 const int i2 = iq2;
8338 const int i3 = iq3;
8339
8340 // permute(0, 2, 1, 3)
8341 memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
8342 }
8343 }
8344}
8345
8346static void ggml_compute_forward_flash_attn_ext_tiled(
8347 const ggml_compute_params * params,
8348 ggml_tensor * dst,
8349 int ir0, int ir1) {
8350 const ggml_tensor * q = dst->src[0];
8351 const ggml_tensor * k = dst->src[1];
8352 const ggml_tensor * v = dst->src[2];
8353 const ggml_tensor * mask = dst->src[3];
8354 const ggml_tensor * sinks = dst->src[4];
8355
8356 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8357 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8358 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8359 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8360 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8361 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8362 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8363 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8364
8365 const int64_t DK = nek0;
8366 const int64_t DV = nev0;
8367 const int64_t N = neq1;
8368
8369 GGML_ASSERT(ne0 == DV);
8370 GGML_ASSERT(ne2 == N);
8371
8372 // input tensor rows must be contiguous
8373 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8374 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8375 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8376
8377 GGML_ASSERT(neq0 == DK);
8378 GGML_ASSERT(nek0 == DK);
8379 GGML_ASSERT(nev0 == DV);
8380
8381 GGML_ASSERT(neq1 == N);
8382
8383 // dst cannot be transposed or permuted
8384 GGML_ASSERT(nb0 == sizeof(float));
8385 GGML_ASSERT(nb0 <= nb1);
8386 GGML_ASSERT(nb1 <= nb2);
8387 GGML_ASSERT(nb2 <= nb3);
8388
8389 GGML_ASSERT(k->type == v->type);
8390 const ggml_type kv_type = k->type;
8391
8392 const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
8393 const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
8394 const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot;
8395 const size_t kv_type_size = ggml_type_size(kv_type);
8396
8397 // broadcast factors
8398 const int64_t rk2 = neq2/nek2;
8399 const int64_t rk3 = neq3/nek3;
8400
8401 const int64_t rv2 = neq2/nev2;
8402 const int64_t rv3 = neq3/nev3;
8403
8404 float scale = 1.0f;
8405 float max_bias = 0.0f;
8406 float logit_softcap = 0.0f;
8407
8408 memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
8409 memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
8410 memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
8411
8412 if (logit_softcap != 0) {
8413 scale /= logit_softcap;
8414 }
8415
8416 const uint32_t n_head = neq2;
8417 const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
8418
8419 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8420 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8421
8422 int ith = params->ith;
8423
8424 static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q;
8425 static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
8426
8427 GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
8428
8429 int ir = ir0;
8430 while (ir < ir1) {
8431 // q indices for the start of this tile
8432 const int iq3 = ir/(neq2*neq1);
8433 const int iq2 = (ir - iq3*neq2*neq1)/neq1;
8434 const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
8435
8436 // Number of valid rows in this tile:
8437 // - limited by tile size (Q_TILE_SZ)
8438 // - limited by chunk boundary (ir1 - ir)
8439 // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
8440 const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
8441 GGML_ASSERT(tile_rows > 0);
8442
8443 const uint32_t h = iq2; // head index
8444 const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
8445
8446 float S[Q_TILE_SZ];
8447 float M[Q_TILE_SZ];
8448
8449 for (int i = 0 ; i < Q_TILE_SZ; ++i) {
8450 S[i] = 0.;
8451 M[i] = -INFINITY;
8452 }
8453
8454 // Per-thread scratch layout:
8455 // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type)
8456 // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
8457 // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float)
8458 // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator)
8459     //   V32:    KV_TILE_SZ * DV   (F32 buffer for V tile - used for f16 conversion)
8460 float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
8461
8462 void * Q_q = base;
8463 float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
8464 float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
8465 float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ;
8466 float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile
8467
8468 memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
8469 memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
8470
8471 // k indices
8472 const int ik3 = iq3 / rk3;
8473 const int ik2 = iq2 / rk2;
8474
8475 // v indices
8476 const int iv3 = iq3 / rv3;
8477 const int iv2 = iq2 / rv2;
8478
8479 for (int tq = 0; tq < tile_rows; tq++) {
8480 const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
8481 kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
8482 }
8483 // Zero-pad remaining rows
8484 for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
8485 memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
8486 }
8487
8488 for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
8489
8490 // skip the tile entirely if all the masks are -inf
8491 if (mask) {
8492 bool can_skip = true;
8493 for (int tq = 0; tq < tile_rows; tq++) {
8494 const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
8495 for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8496 mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
8497 if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
8498 can_skip = false;
8499 }
8500 }
8501 }
8502
8503 if (can_skip) {
8504 continue;
8505 }
8506 }
8507
8508 for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8509 const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
8510 for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8511 const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
8512 float s;
8513 kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
8514 KQ[tq * KV_TILE_SZ + tk] = s * scale;
8515 }
8516 }
8517
8518 if (logit_softcap != 0.0f) {
8519 ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
8520 ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
8521 }
8522
8523 if (mask) {
8524 ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
8525 }
8526
8527 bool skip[Q_TILE_SZ] = {};
8528
8529 for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8530 float * kq_row = KQ + tq * KV_TILE_SZ;
8531
8532 float tile_max;
8533 ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
8534
8535 if (tile_max == -INFINITY) {
8536 skip[tq] = true;
8537 continue;
8538 }
8539
8540 const float Mold = M[tq];
8541 const float Mnew = fmaxf(Mold, tile_max);
8542
8543 if (Mnew > Mold) {
8544 const float ms = expf(Mold - Mnew);
8545 ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8546 S[tq] *= ms;
8547 }
8548 M[tq] = Mnew;
8549
8550
8551 S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
8552 }
8553
8554 // Convert V tile to F32 first (if F16), then do MAD
8555             // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
8556 // TODO: on ARM, native f16 should be faster
8557 if (kv_type == GGML_TYPE_F16) {
8558 for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8559 const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8560 ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
8561 }
8562 for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8563 if (skip[tq]) continue;
8564 float * vkq_row = VKQ32 + tq * DV;
8565 for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8566 const float p = KQ[tq * KV_TILE_SZ + tk];
8567 ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
8568 }
8569 }
8570 } else {
8571 for (int tq = 0; tq < Q_TILE_SZ; tq++) {
8572 if (skip[tq]) continue;
8573 float * vkq_row = VKQ32 + tq * DV;
8574 for (int tk = 0; tk < KV_TILE_SZ; tk++) {
8575 const float p = KQ[tq * KV_TILE_SZ + tk];
8576 const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
8577 ggml_vec_mad_f32(DV, vkq_row, v_row, p);
8578 }
8579 }
8580 }
8581 }
8582
8583 // sinks (apply only to valid rows in the tile)
8584 if (sinks) {
8585 const float s = ((float *)((char *) sinks->data))[h];
8586
8587 for (int tq = 0; tq < tile_rows; tq++) {
8588 float ms = 1.0f;
8589 float vs = 1.0f;
8590
8591 if (s > M[tq]) {
8592 ms = expf(M[tq] - s);
8593 ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
8594 } else {
8595 vs = expf(s - M[tq]);
8596 }
8597
8598 S[tq] = S[tq] * ms + vs;
8599 }
8600 }
8601
8602 for (int tq = 0; tq < tile_rows; tq++) {
8603 // V /= S
8604 const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
8605 ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
8606
8607 // dst indices
8608 const int i1 = iq1 + tq;
8609 const int i2 = iq2;
8610 const int i3 = iq3;
8611
8612 // permute(0, 2, 1, 3)
8613 memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
8614 }
8615
8616 ir += tile_rows;
8617 }
8618}
8619
8620// Reduction function: combines partial results across KV chunks
8621// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
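// Merging two partials (M1, S1, V1) and (M2, S2, V2) computed over disjoint KV ranges:
//   M = max(M1, M2)
//   S = S1*expf(M1 - M) + S2*expf(M2 - M)
//   V = V1*expf(M1 - M) + V2*expf(M2 - M)
// The loop below folds the chunks into (M_final, S_final, VKQ_final) one at a time this way.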
8622static void ggml_flash_attn_ext_reduce_partials(
8623 const ggml_compute_params * params,
8624 ggml_tensor * dst,
8625 const int64_t n_chunks,
8626 const int64_t chunk_size) {
8627
8628 const ggml_tensor * q = dst->src[0];
8629 const ggml_tensor * k = dst->src[1];
8630 const ggml_tensor * v = dst->src[2];
8631
8632 const int64_t DK = k->ne[0];
8633 const int64_t DV = v->ne[0];
8634 const int64_t nek1 = k->ne[1];
8635 const int64_t n_q_heads = q->ne[2];
8636
8637 const int ith = params->ith;
8638 const int nth = params->nth;
8639
8640 const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
8641 float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
8642
8643 const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8644 const int64_t partial_size = 2 + DV;
8645 const float * partials_base = (const float *) params->wdata + partials_offset;
8646
8647 // Output layout
8648 const int64_t ne1 = dst->ne[1];
8649 const int64_t ne2 = dst->ne[2];
8650 const size_t nb1 = dst->nb[1];
8651
8652 // Each thread reduces a subset of query heads
8653 for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
8654 float M_final = -INFINITY;
8655 float S_final = 0.0f;
8656 float * VKQ_final = thread_wdata;
8657 memset(VKQ_final, 0, DV * sizeof(float));
8658
8659 // Combine partials from all chunks
8660 for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
8661 const int64_t ic_start = chunk_idx * chunk_size;
8662 if (ic_start >= nek1) continue;
8663
8664 const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
8665 const float M_chunk = partial[0];
8666 const float S_chunk = partial[1];
8667 const float * VKQ_chunk = partial + 2;
8668
8669 if (S_chunk == 0.0f) continue;
8670
8671 const float M_new = fmaxf(M_final, M_chunk);
8672 const float scale_old = expf(M_final - M_new);
8673 const float scale_new = expf(M_chunk - M_new);
8674
8675 for (int64_t d = 0; d < DV; ++d) {
8676 VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
8677 }
8678 S_final = S_final * scale_old + S_chunk * scale_new;
8679 M_final = M_new;
8680 }
8681
8682 // Normalize and write to output
8683 if (S_final != 0.0f) {
8684 const float S_inv = 1.0f / S_final;
8685 ggml_vec_scale_f32(DV, VKQ_final, S_inv);
8686 }
8687 // iq1=0, iq3=0 for decode
8688 memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
8689 }
8690}
8691
8692static void ggml_compute_forward_flash_attn_ext_f16(
8693 const ggml_compute_params * params,
8694 ggml_tensor * dst) {
8695
8696 const ggml_tensor * q = dst->src[0];
8697 const ggml_tensor * k = dst->src[1];
8698 const ggml_tensor * v = dst->src[2];
8699
8700 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8701 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8702 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8703 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8704 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8705 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8706 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8707 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8708
8709 const int64_t DK = nek0;
8710 const int64_t DV = nev0;
8711 const int64_t N = neq1;
8712
8713
8714 GGML_ASSERT(ne0 == DV);
8715 GGML_ASSERT(ne2 == N);
8716
8717 // input tensor rows must be contiguous
8718 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8719 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8720 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8721
8722 GGML_ASSERT(neq0 == DK);
8723 GGML_ASSERT(nek0 == DK);
8724 GGML_ASSERT(nev0 == DV);
8725
8726 GGML_ASSERT(neq1 == N);
8727
8728 // dst cannot be transposed or permuted
8729 GGML_ASSERT(nb0 == sizeof(float));
8730 GGML_ASSERT(nb0 <= nb1);
8731 GGML_ASSERT(nb1 <= nb2);
8732 GGML_ASSERT(nb2 <= nb3);
8733
8734 const int ith = params->ith;
8735 const int nth = params->nth;
8736
8737 // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
8738 const bool use_ref = params->use_ref;
8739
8740 const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
8741 const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
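    // split-KV sketch: for single-token decode (neq1 == 1) there is only one q row per head,
    // so parallelizing over q rows would leave most threads idle; instead each thread scans a
    // contiguous chunk of the KV cache for every head, stores (M, S, VKQ) partials, and the
    // partials are merged after the barrier in ggml_flash_attn_ext_reduce_partials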
8742
8743 if (use_split_kv_path) {
8744 const int64_t chunk_size = (nek1 + nth - 1) / nth;
8745
8746 // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
8747 const int64_t partial_size = 2 + DV;
8748 float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
8749
8750 const int64_t ic_start = ith * chunk_size;
8751 const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
8752
8753 const int64_t partial_stride = nth * partial_size;
8754 float * chunk_partials = partials_base + ith * partial_size;
8755
8756 if (ic_start < nek1) {
8757 for (int64_t q_head = 0; q_head < neq2; q_head++) {
8758 ggml_compute_forward_flash_attn_ext_f16_one_chunk(
8759 params, dst, q_head, q_head + 1, ic_start, ic_end,
8760 chunk_partials, partial_stride);
8761 }
8762 } else {
8763 for (int64_t q_head = 0; q_head < neq2; q_head++) {
8764 float * q_partials = chunk_partials + q_head * partial_stride;
8765 q_partials[0] = -INFINITY; // M
8766 q_partials[1] = 0.0f; // S
8767 }
8768 }
8769
8770 ggml_barrier(params->threadpool);
8771 ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
8772 } else {
8773
8774 // total rows in q
8775 const int64_t nr = neq1*neq2*neq3;
8776
8777 // disable for NUMA
8778 const bool disable_chunking = ggml_is_numa();
8779
8780 // 4x chunks per thread
8781 int nth_scaled = nth * 4;
8782 int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
8783 int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
8784
8785 if (nth == 1 || nchunk < nth || disable_chunking) {
8786 nchunk = nth;
8787 }
8788
8789 if (ith == 0) {
8790 ggml_threadpool_chunk_set(params->threadpool, nth);
8791 }
8792
8793 ggml_barrier(params->threadpool);
8794
8795 const int64_t dr = (nr + nchunk - 1) / nchunk;
8796
8797 static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
8798 static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
8799 const bool use_tiled = !use_ref &&
8800 (q->type == GGML_TYPE_F32 &&
8801 kv_is_f32_or_f16 &&
8802 k->type == v->type &&
8803 nek1 % KV_TILE_SZ == 0 &&
8804 neq1 >= Q_TILE_SZ);
8805
8806 int current_chunk = ith;
8807
8808 while (current_chunk < nchunk) {
8809 const int64_t ir0 = dr * current_chunk;
8810 const int64_t ir1 = MIN(ir0 + dr, nr);
8811
8812 if (use_tiled) {
8813 ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
8814 } else {
8815 ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
8816 }
8817
8818 current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
8819 }
8820 }
8821}
8822
8823void ggml_compute_forward_flash_attn_ext(
8824 const ggml_compute_params * params,
8825 ggml_tensor * dst) {
8826 switch (dst->op_params[3]) {
8827 case GGML_PREC_DEFAULT:
8828 case GGML_PREC_F32:
8829 {
8830 // uses F32 accumulators
8831 ggml_compute_forward_flash_attn_ext_f16(params, dst);
8832 } break;
8833 default:
8834 {
8835 GGML_ABORT("fatal error");
8836 }
8837 }
8838}
8839
8840// ggml_compute_forward_flash_attn_back
8841
8842static void ggml_compute_forward_flash_attn_back_f32(
8843 const ggml_compute_params * params,
8844 const bool masked,
8845 ggml_tensor * dst) {
8846
8847 const ggml_tensor * q = dst->src[0];
8848 const ggml_tensor * k = dst->src[1];
8849 const ggml_tensor * v = dst->src[2];
8850 const ggml_tensor * d = dst->src[3];
8851
8852 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8853 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8854 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8855 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8856 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8857 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8858 GGML_TENSOR_LOCALS(int64_t, ned, d, ne)
8859 GGML_TENSOR_LOCALS(size_t, nbd, d, nb)
8860 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8861 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8862
8863 const int ith = params->ith;
8864 const int nth = params->nth;
8865
8866 const int64_t D = neq0;
8867 const int64_t N = neq1;
8868 const int64_t P = nek1 - N;
8869 const int64_t M = P + N;
8870
8871 const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
8872 const int mxDM = MAX(D, Mup);
8873
8874 // GGML_ASSERT(ne0 == D);
8875 // GGML_ASSERT(ne1 == N);
8876 GGML_ASSERT(P >= 0);
8877
8878 GGML_ASSERT(nbq0 == sizeof(float));
8879 GGML_ASSERT(nbk0 == sizeof(float));
8880 GGML_ASSERT(nbv0 == sizeof(float));
8881
8882 GGML_ASSERT(neq0 == D);
8883 GGML_ASSERT(nek0 == D);
8884 GGML_ASSERT(nev1 == D);
8885 GGML_ASSERT(ned0 == D);
8886
8887 GGML_ASSERT(neq1 == N);
8888 GGML_ASSERT(nek1 == N + P);
8889 GGML_ASSERT(nev1 == D);
8890 GGML_ASSERT(ned1 == N);
8891
8892 // dst cannot be transposed or permuted
8893 GGML_ASSERT(nb0 == sizeof(float));
8894 GGML_ASSERT(nb0 <= nb1);
8895 GGML_ASSERT(nb1 <= nb2);
8896 GGML_ASSERT(nb2 <= nb3);
8897
8898 if (ith == 0) {
8899 memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
8900 }
8901 ggml_barrier(params->threadpool);
8902
8903 const int64_t elem_q = ggml_nelements(q);
8904 const int64_t elem_k = ggml_nelements(k);
8905
8906 ggml_type result_type = dst->type;
8907 GGML_ASSERT(ggml_blck_size(result_type) == 1);
8908 const size_t tsize = ggml_type_size(result_type);
8909
8910 const size_t offs_q = 0;
8911 const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
8912 const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
8913
8914 void * grad_q = (char *) dst->data;
8915 void * grad_k = (char *) dst->data + offs_k;
8916 void * grad_v = (char *) dst->data + offs_v;
8917
8918 const size_t nbgq1 = nb0*neq0;
8919 const size_t nbgq2 = nb0*neq0*neq1;
8920 const size_t nbgq3 = nb0*neq0*neq1*neq2;
8921
8922 const size_t nbgk1 = nb0*nek0;
8923 const size_t nbgk2 = nb0*nek0*nek1;
8924 const size_t nbgk3 = nb0*nek0*nek1*neq2;
8925
8926 const size_t nbgv1 = nb0*nev0;
8927 const size_t nbgv2 = nb0*nev0*nev1;
8928 const size_t nbgv3 = nb0*nev0*nev1*neq2;
8929
8930 // parallelize by k rows using ggml_vec_dot_f32
8931
8932 // total rows in k
8933 const int nr = nek2*nek3;
8934
8935 // rows per thread
8936 const int dr = (nr + nth - 1)/nth;
8937
8938 // row range for this thread
8939 const int ir0 = dr*ith;
8940 const int ir1 = MIN(ir0 + dr, nr);
8941
8942 const float scale = 1.0f/sqrtf(D);
8943
8944 //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
8945
8946 // how often k2 (and v2) is repeated in q2
8947 int nrep = neq2/nek2;
8948
8949 for (int ir = ir0; ir < ir1; ++ir) {
8950 // q indices
8951 const int ik3 = ir/(nek2);
8952 const int ik2 = ir - ik3*nek2;
8953
8954 const int iq3 = ik3;
8955 const int id3 = ik3;
8956 const int iv3 = ik3;
8957 const int iv2 = ik2;
8958
8959 for (int irep = 0; irep < nrep; ++irep) {
8960 const int iq2 = ik2 + irep*nek2;
8961 const int id2 = iq2;
8962
8963 // (ik2 + irep*nek2) % nek2 == ik2
8964 for (int iq1 = 0; iq1 < neq1; ++iq1) {
8965 const int id1 = iq1;
8966
8967                 // not sure about CACHE_LINE_SIZE_F32:
8968                 // - maybe it should not be multiplied by 2, and should be excluded from the 1*(..) offset used for SM?
8969 float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
8970 float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
8971
8972 for (int i = M; i < Mup; ++i) {
8973 S[i] = -INFINITY;
8974 }
8975
8976 const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
8977 for (int64_t ic = 0; ic < masked_begin; ++ic) {
8978 // k indices
8979 const int ik1 = ic;
8980
8981 // S indices
8982 const int i1 = ik1;
8983
8984 ggml_vec_dot_f32(neq0,
8985 S + i1, 0,
8986 (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
8987 (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
8988 }
8989
8990 // scale
8991 ggml_vec_scale_f32(masked_begin, S, scale);
8992
8993 for (int64_t i = masked_begin; i < M; i++) {
8994 S[i] = -INFINITY;
8995 }
8996
8997 // softmax
8998 // exclude known -INF S[..] values from max and loop
8999             // don't forget to set their SM values to zero
9000 {
9001 float max = -INFINITY;
9002 ggml_vec_max_f32(masked_begin, &max, S);
9003
9004 ggml_float sum = 0.0;
9005 {
9006#ifdef GGML_SOFT_MAX_ACCELERATE
9007 max = -max;
9008                     vDSP_vsadd(S, 1, &max, SM, 1, Mup); // read from S (SM is not yet initialized here), matching the non-Accelerate path
9009 vvexpf(SM, SM, &Mup);
9010 ggml_vec_sum_f32(Mup, &sum, SM);
9011#else
9012 sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
9013#endif
9014 }
9015
9016 assert(sum > 0.0);
9017
9018 sum = 1.0/sum;
9019 ggml_vec_scale_f32(masked_begin, SM, sum);
9020
9021 }
9022
9023 // step-by-step explanation
9024 {
9025 // forward-process shape grads from backward process
9026 // parallel_for ik2,ik3:
9027 // for irep:
9028 // iq2 = ik2 + irep*nek2
9029 // k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
9030 // q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
9031 // v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
9032 // for iq1:
9033 // kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
9034 // qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
9035 // vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
9036 // S0 = -Inf [D,1,1,1]
9037 // ~S1[i] = dot(kcur[:D,i], qcur)
9038 // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
9039 // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
9040 // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
9041 // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
9042 // ~S5[i] = dot(vcur[:,i], S4)
9043 // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
9044 // ~dst[i,iq1,iq2,iq3] = S5[i] ^
9045 // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
9046 // dst backward-/ grad[dst] = d
9047 //
9048 // output gradients with their dependencies:
9049 //
9050 // grad[kcur] = grad[S1].T @ qcur
9051 // grad[S1] = diag_mask_zero(grad[S3], P) * scale
9052 // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
9053 // grad[S4] = grad[S5] @ vcur
9054 // grad[S4] = d[:D,id1,id2,id3] @ vcur
9055 // grad[qcur] = grad[S1] @ kcur
9056 // grad[vcur] = grad[S5].T @ S4
9057 // grad[vcur] = d[:D,id1,id2,id3].T @ S4
9058 //
9059 // in post-order:
9060 //
9061 // S1 = qcur @ kcur.T
9062 // S2 = S1 * scale
9063 // S3 = diag_mask_inf(S2, P)
9064 // S4 = softmax(S3)
9065 // grad[S4] = d[:D,id1,id2,id3] @ vcur
9066 // grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
9067 // grad[S1] = diag_mask_zero(grad[S3], P) * scale
9068 // grad[qcur] = grad[S1] @ kcur
9069 // grad[kcur] = grad[S1].T @ qcur
9070 // grad[vcur] = d[:D,id1,id2,id3].T @ S4
9071 //
9072 // using less variables (SM=S4):
9073 //
9074 // S = diag_mask_inf(qcur @ kcur.T * scale, P)
9075 // SM = softmax(S)
9076 // S = d[:D,iq1,iq2,iq3] @ vcur
9077 // dot_SM_gradSM = dot(SM, S)
9078 // S = SM * (S - dot(SM, S))
9079 // S = diag_mask_zero(S, P) * scale
9080 //
9081 // grad[q][:D,iq1,iq2,iq3] += S @ kcur
9082 // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
9083 // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
9084 }
9085
9086 // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
9087 // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
9088 // for ic:
9089 // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
9090 // exclude known future zero S[..] values from operation
9091 ggml_vec_set_f32(masked_begin, S, 0);
9092 for (int64_t ic = 0; ic < D; ++ic) {
9093 ggml_vec_mad_f32(masked_begin,
9094 S,
9095 (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
9096 *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
9097 }
9098
9099 // S = SM * (S - dot(SM, S))
9100 float dot_SM_gradSM = 0;
9101 ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
9102 ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
9103 ggml_vec_mul_f32 (masked_begin, S, S, SM);
9104
9105 // S = diag_mask_zero(S, P) * scale
9106 // already done by above ggml_vec_set_f32
9107
9108 // exclude known zero S[..] values from operation
9109 ggml_vec_scale_f32(masked_begin, S, scale);
9110
9111 // S shape [M,1]
9112 // SM shape [M,1]
9113 // kcur shape [D,M]
9114 // qcur shape [D,1]
9115 // vcur shape [M,D]
9116
9117 // grad[q][:D,iq1,iq2,iq3] += S @ kcur
9118 // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
9119 // for ic:
9120 // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
9121 // exclude known zero S[..] values from loop
9122 for (int64_t ic = 0; ic < masked_begin; ++ic) {
9123 ggml_vec_mad_f32(D,
9124 (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
9125 (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
9126 S[ic]);
9127 }
9128
9129 // grad[k][:D,:M,iq2,iq3] += S.T @ qcur
9130 // for ic:
9131 // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
9132 // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
9133 // exclude known zero S[..] values from loop
9134 for (int64_t ic = 0; ic < masked_begin; ++ic) {
9135 ggml_vec_mad_f32(D,
9136 (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
9137 (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
9138 S[ic]);
9139 }
9140
9141 // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
9142 // for ic:
9143 // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
9144 // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
9145 // exclude known zero SM[..] values from mad
9146 for (int64_t ic = 0; ic < D; ++ic) {
9147 ggml_vec_mad_f32(masked_begin,
9148 (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
9149 SM,
9150 *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
9151 }
9152 }
9153 }
9154 }
9155}
9156
9157void ggml_compute_forward_flash_attn_back(
9158 const ggml_compute_params * params,
9159 const bool masked,
9160 ggml_tensor * dst) {
9161
9162 const ggml_tensor * q = dst->src[0];
9163
9164 switch (q->type) {
9165 case GGML_TYPE_F32:
9166 {
9167 ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
9168 } break;
9169 default:
9170 {
9171 GGML_ABORT("fatal error");
9172 }
9173 }
9174}
9175
9176// ggml_compute_forward_ssm_conv
9177
9178static void ggml_compute_forward_ssm_conv_f32(
9179 const ggml_compute_params * params,
9180 ggml_tensor * dst) {
9181 const ggml_tensor * src0 = dst->src[0]; // conv_x
9182 const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
9183
9184 const int ith = params->ith;
9185 const int nth = params->nth;
9186
9187 const int nc = src1->ne[0]; // d_conv
9188 const int ncs = src0->ne[0]; // d_conv - 1 + n_t
9189 const int nr = src0->ne[1]; // d_inner
9190 const int n_t = dst->ne[1]; // tokens per sequence
9191 const int n_s = dst->ne[2]; // number of sequences in the batch
9192
9193 GGML_ASSERT( dst->ne[0] == nr);
9194 GGML_ASSERT(src0->nb[0] == sizeof(float));
9195 GGML_ASSERT(src1->nb[0] == sizeof(float));
9196 GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
9197
9198 // rows per thread
9199 const int dr = (nr + nth - 1)/nth;
9200
9201 // row range for this thread
9202 const int ir0 = dr*ith;
9203 const int ir1 = MIN(ir0 + dr, nr);
9204 const int ir = ir1 - ir0;
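    // sliding-window sketch: conv_x holds d_conv - 1 + n_t values per channel, so advancing
    // the source pointer by one float per token (the i2*src0->nb[0] term below) exposes the
    // length-d_conv window ending at that token; each output element is the dot product of
    // that window with the per-channel conv1d weights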
9205
9206 for (int i3 = 0; i3 < n_s; ++i3) {
9207 for (int i2 = 0; i2 < n_t; ++i2) {
9208 // {d_conv - 1 + n_t, d_inner, n_seqs}
9209 // sliding window
9210 const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
9211 const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
9212 float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
9213
9214 // TODO: transpose the output for smaller strides for big batches?
9215 // d_inner
9216 for (int i1 = 0; i1 < ir; ++i1) {
9217 // rowwise dot product
9218 // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
9219 float sumf = 0.0f;
9220
9221 // d_conv
9222 for (int i0 = 0; i0 < nc; ++i0) {
9223 sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
9224 }
9225 x[i1] = sumf;
9226 }
9227 }
9228 }
9229}
9230
9231void ggml_compute_forward_ssm_conv(
9232 const ggml_compute_params * params,
9233 ggml_tensor * dst) {
9234 switch (dst->src[0]->type) {
9235 case GGML_TYPE_F32:
9236 {
9237 ggml_compute_forward_ssm_conv_f32(params, dst);
9238 } break;
9239 default:
9240 {
9241 GGML_ABORT("fatal error");
9242 }
9243 }
9244}
9245
9246// ggml_compute_forward_ssm_scan
9247
9248static void ggml_compute_forward_ssm_scan_f32(
9249 const ggml_compute_params * params,
9250 ggml_tensor * dst) {
9251 const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+}
9252 const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs}
9253 const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
9254 const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head}
9255 const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs}
9256 const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs}
9257 const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
9258
9259 const int ith = params->ith;
9260 const int nth = params->nth;
9261
9262 const int64_t nc = src0->ne[0]; // d_state
9263 const int64_t nr = src0->ne[1]; // dim
9264 const int64_t nh = src1->ne[1]; // n_head
9265 const int64_t ng = src4->ne[1];
9266 const int64_t nt = src1->ne[2]; // number of tokens per sequence
9267 const int64_t ns = src1->ne[3]; // number of sequences in the batch
9268
9269 // can't use ggml_nbytes because src1 is not necessarily contiguous
9270 const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
9271
9272 GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
9273 GGML_ASSERT(src0->nb[0] == sizeof(float));
9274 GGML_ASSERT(src1->nb[0] == sizeof(float));
9275 GGML_ASSERT(src2->nb[0] == sizeof(float));
9276 GGML_ASSERT(src3->nb[0] == sizeof(float));
9277 GGML_ASSERT(src4->nb[0] == sizeof(float));
9278 GGML_ASSERT(src5->nb[0] == sizeof(float));
9279 GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
9280 GGML_ASSERT(nh % ng == 0);
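    // recurrence (sketch): with dt' = softplus(dt[h]) and x' = dt'*x, each state element is
    // updated as s = s0*expf(dt'*A) + x'*B and the output is y = sum over d_state of s*C;
    // for Mamba-2 (src3->ne[0] == 1) A is a scalar per head, so expf(dt'*A) is hoisted out
    // of the d_state loop as dA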
9281
9282 // heads per thread
9283 const int dh = (nh + nth - 1)/nth;
9284
9285 // head range for this thread
9286 const int ih0 = dh*ith;
9287 const int ih1 = MIN(ih0 + dh, nh);
9288
9289 const int32_t * ids = (const int32_t *) src6->data;
9290
9291 for (int i3 = 0; i3 < ns; ++i3) {
9292 const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
9293 float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
9294
9295 for (int i2 = 0; i2 < nt; ++i2) {
9296 const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
9297 const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
9298 const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
9299 const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
9300 const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
9301 float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
9302
9303 if (src3->ne[0] == 1) {
9304 // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
9305
9306 // n_head
9307 for (int h = ih0; h < ih1; ++h) {
9308 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9309 const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
9310 const float dA = expf(dt_soft_plus * A[h]);
9311 const int g = h / (nh / ng); // repeat_interleave
9312
9313 // dim
9314 for (int i1 = 0; i1 < nr; ++i1) {
9315 const int ii = i1 + h*nr;
9316 const float x_dt = x[ii] * dt_soft_plus;
9317 float sumf = 0.0f;
9318#if defined(GGML_SIMD)
9319 #if defined(__ARM_FEATURE_SVE)
9320 const int ggml_f32_epr = svcntw();
9321 const int ggml_f32_step = 1 * ggml_f32_epr;
9322
9323 const int np = (nc & ~(ggml_f32_step - 1));
9324
9325 GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
9326
9327 GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
9328 GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
9329
9330 for (int i = 0; i < np; i += ggml_f32_step) {
9331 // TODO: maybe unroll more?
9332 for (int j = 0; j < 1; j++) {
9333 GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
9334 GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
9335 GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
9336
9337 t0 = GGML_F32_VEC_MUL(t0, adA);
9338 t1 = GGML_F32_VEC_MUL(t1, axdt);
9339
9340 t0 = GGML_F32_VEC_ADD(t0, t1);
9341
9342 sum = GGML_F32_VEC_FMA(sum, t0, t2);
9343
9344 GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
9345 }
9346 }
9347
9348 sumf = GGML_F32xt_REDUCE_ONE(sum);
9349 #elif defined(__riscv_v_intrinsic)
9350 // todo: RVV implementation
9351 const int np = 0;
9352 #else
9353 const int np = (nc & ~(GGML_F32_STEP - 1));
9354
9355 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
9356
9357 GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
9358 GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
9359
9360 GGML_F32_VEC ax[GGML_F32_ARR];
9361 GGML_F32_VEC ay[GGML_F32_ARR];
9362 GGML_F32_VEC az[GGML_F32_ARR];
9363
9364 for (int i = 0; i < np; i += GGML_F32_STEP) {
9365 for (int j = 0; j < GGML_F32_ARR; j++) {
9366 ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
9367 ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
9368 az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
9369
9370 ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
9371 ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
9372
9373 ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
9374
9375 sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
9376
9377 GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
9378 }
9379 }
9380
9381 // reduce sum0..sum3 to sum0
9382 GGML_F32_VEC_REDUCE(sumf, sum);
9383 #endif
9384#else
9385 const int np = 0;
9386#endif
9387 // d_state
9388 for (int i0 = np; i0 < nc; ++i0) {
9389 const int i = i0 + ii*nc;
9390 const int ig = i0 + g*nc;
9391 // state = prev_state * dA + dB * x
9392 const float state = (s0[i] * dA) + (B[ig] * x_dt);
9393 // y = rowwise_dotprod(state, C)
9394 sumf += state * C[ig];
9395 s[i] = state;
9396 }
9397 y[ii] = sumf;
9398 }
9399 }
9400 } else {
9401 // Mamba-1 has an element-wise decay factor for the states
9402
9403 // n_head
9404 for (int h = ih0; h < ih1; ++h) {
9405 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
9406 const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
9407 const int g = h / (nh / ng); // repeat_interleave
9408
9409 // dim
9410 for (int i1 = 0; i1 < nr; ++i1) {
9411 const int ii = i1 + h*nr;
9412 const float x_dt = x[ii] * dt_soft_plus;
9413#if defined(__ARM_FEATURE_SVE)
9414 svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
9415 svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
9416 svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
9417
9418 // d_state
9419 // TODO: what happens when (d_state % svcntw()) != 0?
9420 for (int64_t k = 0; k < nc; k += svcntw()) {
9421 svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
9422 svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
9423 svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
9424 svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
9425
9426 svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
9427 t1 = exp_ps_sve(svptrue_b32(), t1);
9428 svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
9429
9430 vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
9431 r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
9432
9433 GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
9434 }
9435 y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
9436#else
9437 float sumf = 0.0f;
9438 // NOTE: can't really use GGML_SIMD here because d_state is usually 16
9439 // and also because expf is used within the loop.
9440 // d_state
9441 for (int i0 = 0; i0 < nc; ++i0) {
9442 const int i = i0 + ii*nc;
9443 const int ig = i0 + g*nc;
9444 // state = prev_state * dA + dB * x
9445 const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
9446 // y = rowwise_dotprod(state, C)
9447 sumf += state * C[ig];
9448 s[i] = state;
9449 }
9450 y[ii] = sumf;
9451#endif
9452 }
9453 }
9454 }
9455 // use the output as the source when it's not the first token-wise iteration
9456 s0 = s;
9457 }
9458 }
9459}
9460
9461void ggml_compute_forward_ssm_scan(
9462 const ggml_compute_params * params,
9463 ggml_tensor * dst) {
9464 switch (dst->src[0]->type) {
9465 case GGML_TYPE_F32:
9466 {
9467 ggml_compute_forward_ssm_scan_f32(params, dst);
9468 } break;
9469 default:
9470 {
9471 GGML_ABORT("fatal error");
9472 }
9473 }
9474}
9475
9476// ggml_compute_forward_win_part
9477
9478static void ggml_compute_forward_win_part_f32(
9479 const ggml_compute_params * params,
9480 ggml_tensor * dst) {
9481 GGML_UNUSED(params);
9482
9483 const ggml_tensor * src0 = dst->src[0];
9484
9485 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9486 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
9487
9488 const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
9489 const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
9490 const int32_t w = ((const int32_t *)(dst->op_params))[2];
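    // partition src0 into nep0 x nep1 windows of size w along dims 1 and 2, one window per
    // slice of dst along dim 3; positions that fall outside src0 are zero-padded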
9491
9492 assert(ne00 == ne0);
9493 assert(ne3 == nep0*nep1);
9494
9495 // TODO: optimize / multi-thread
9496 for (int py = 0; py < nep1; ++py) {
9497 for (int px = 0; px < nep0; ++px) {
9498 const int64_t i3 = py*nep0 + px;
9499 for (int64_t i2 = 0; i2 < ne2; ++i2) {
9500 for (int64_t i1 = 0; i1 < ne1; ++i1) {
9501 for (int64_t i0 = 0; i0 < ne0; ++i0) {
9502 const int64_t i02 = py*w + i2;
9503 const int64_t i01 = px*w + i1;
9504 const int64_t i00 = i0;
9505
9506 const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0 + i0;
9507 const int64_t j = i02*ne01*ne00 + i01*ne00 + i00;
9508
9509 if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
9510 ((float *) dst->data)[i] = 0.0f;
9511 } else {
9512 ((float *) dst->data)[i] = ((float *) src0->data)[j];
9513 }
9514 }
9515 }
9516 }
9517 }
9518 }
9519}
9520
9521void ggml_compute_forward_win_part(
9522 const ggml_compute_params * params,
9523 ggml_tensor * dst) {
9524
9525 const ggml_tensor * src0 = dst->src[0];
9526
9527 switch (src0->type) {
9528 case GGML_TYPE_F32:
9529 {
9530 ggml_compute_forward_win_part_f32(params, dst);
9531 } break;
9532 default:
9533 {
9534 GGML_ABORT("fatal error");
9535 }
9536 }
9537}
9538
9539// ggml_compute_forward_win_unpart
9540
9541static void ggml_compute_forward_win_unpart_f32(
9542 const ggml_compute_params * params,
9543 ggml_tensor * dst) {
9544 GGML_UNUSED(params);
9545
9546 const ggml_tensor * src0 = dst->src[0];
9547
9548 GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
9549 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
9550
9551 const int32_t w = ((const int32_t *)(dst->op_params))[0];
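    // reverse of ggml_compute_forward_win_part: gather each output element back from its
    // window in src0, dropping the padded border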
9552
9553 // padding
9554 const int px = (w - ne1%w)%w;
9555 //const int py = (w - ne2%w)%w;
9556
9557 const int npx = (px + ne1)/w;
9558 //const int npy = (py + ne2)/w;
9559
9560 assert(ne0 == ne00);
9561
9562 // TODO: optimize / multi-thread
9563 for (int64_t i2 = 0; i2 < ne2; ++i2) {
9564 for (int64_t i1 = 0; i1 < ne1; ++i1) {
9565 for (int64_t i0 = 0; i0 < ne0; ++i0) {
9566 const int ip2 = i2/w;
9567 const int ip1 = i1/w;
9568
9569 const int64_t i02 = i2%w;
9570 const int64_t i01 = i1%w;
9571 const int64_t i00 = i0;
9572
9573 const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
9574 const int64_t j = i2*ne1*ne0 + i1*ne0 + i0;
9575
9576 ((float *) dst->data)[j] = ((float *) src0->data)[i];
9577 }
9578 }
9579 }
9580}
9581
9582void ggml_compute_forward_win_unpart(
9583 const ggml_compute_params * params,
9584 ggml_tensor * dst) {
9585
9586 const ggml_tensor * src0 = dst->src[0];
9587
9588 switch (src0->type) {
9589 case GGML_TYPE_F32:
9590 {
9591 ggml_compute_forward_win_unpart_f32(params, dst);
9592 } break;
9593 default:
9594 {
9595 GGML_ABORT("fatal error");
9596 }
9597 }
9598}
9599
// ggml_compute_forward_unary
9601
9602void ggml_compute_forward_unary(
9603 const ggml_compute_params * params,
9604 ggml_tensor * dst) {
9605
9606 const ggml_unary_op op = ggml_get_unary_op(dst);
9607
9608 switch (op) {
9609 case GGML_UNARY_OP_ABS:
9610 {
9611 ggml_compute_forward_abs(params, dst);
9612 } break;
9613 case GGML_UNARY_OP_SGN:
9614 {
9615 ggml_compute_forward_sgn(params, dst);
9616 } break;
9617 case GGML_UNARY_OP_NEG:
9618 {
9619 ggml_compute_forward_neg(params, dst);
9620 } break;
9621 case GGML_UNARY_OP_STEP:
9622 {
9623 ggml_compute_forward_step(params, dst);
9624 } break;
9625 case GGML_UNARY_OP_TANH:
9626 {
9627 ggml_compute_forward_tanh(params, dst);
9628 } break;
9629 case GGML_UNARY_OP_ELU:
9630 {
9631 ggml_compute_forward_elu(params, dst);
9632 } break;
9633 case GGML_UNARY_OP_RELU:
9634 {
9635 ggml_compute_forward_relu(params, dst);
9636 } break;
9637 case GGML_UNARY_OP_SIGMOID:
9638 {
9639 ggml_compute_forward_sigmoid(params, dst);
9640 } break;
9641 case GGML_UNARY_OP_GELU:
9642 {
9643 ggml_compute_forward_gelu(params, dst);
9644 } break;
9645 case GGML_UNARY_OP_GELU_ERF:
9646 {
9647 ggml_compute_forward_gelu_erf(params, dst);
9648 } break;
9649 case GGML_UNARY_OP_GELU_QUICK:
9650 {
9651 ggml_compute_forward_gelu_quick(params, dst);
9652 } break;
9653 case GGML_UNARY_OP_SILU:
9654 {
9655 ggml_compute_forward_silu(params, dst);
9656 } break;
9657 case GGML_UNARY_OP_HARDSWISH:
9658 {
9659 ggml_compute_forward_hardswish(params, dst);
9660 } break;
9661 case GGML_UNARY_OP_HARDSIGMOID:
9662 {
9663 ggml_compute_forward_hardsigmoid(params, dst);
9664 } break;
9665 case GGML_UNARY_OP_EXP:
9666 {
9667 ggml_compute_forward_exp(params, dst);
9668 } break;
9669 case GGML_UNARY_OP_FLOOR:
9670 {
9671 ggml_compute_forward_floor(params, dst);
9672 } break;
9673 case GGML_UNARY_OP_CEIL:
9674 {
9675 ggml_compute_forward_ceil(params, dst);
9676 } break;
9677 case GGML_UNARY_OP_ROUND:
9678 {
9679 ggml_compute_forward_round(params, dst);
9680 } break;
9681 case GGML_UNARY_OP_TRUNC:
9682 {
9683 ggml_compute_forward_trunc(params, dst);
9684 } break;
9685 case GGML_UNARY_OP_XIELU:
9686 {
9687 ggml_compute_forward_xielu(params, dst);
9688 } break;
9689 case GGML_UNARY_OP_EXPM1:
9690 {
9691 ggml_compute_forward_expm1(params, dst);
9692 } break;
9693 case GGML_UNARY_OP_SOFTPLUS:
9694 {
9695 ggml_compute_forward_softplus(params, dst);
9696 } break;
9697 default:
9698 {
9699 GGML_ABORT("fatal error");
9700 }
9701 }
9702}
9703
// ggml_compute_forward_glu
9705
9706void ggml_compute_forward_glu(
9707 const ggml_compute_params * params,
9708 ggml_tensor * dst) {
9709
9710 const ggml_glu_op op = ggml_get_glu_op(dst);
9711
9712 switch (op) {
9713 case GGML_GLU_OP_REGLU:
9714 {
9715 ggml_compute_forward_reglu(params, dst);
9716 } break;
9717 case GGML_GLU_OP_GEGLU:
9718 {
9719 ggml_compute_forward_geglu(params, dst);
9720 } break;
9721 case GGML_GLU_OP_SWIGLU:
9722 {
9723 ggml_compute_forward_swiglu(params, dst);
9724 } break;
9725 case GGML_GLU_OP_SWIGLU_OAI:
9726 {
9727 ggml_compute_forward_swiglu_oai(params, dst);
9728 } break;
9729 case GGML_GLU_OP_GEGLU_ERF:
9730 {
9731 ggml_compute_forward_geglu_erf(params, dst);
9732 } break;
9733 case GGML_GLU_OP_GEGLU_QUICK:
9734 {
9735 ggml_compute_forward_geglu_quick(params, dst);
9736 } break;
9737 default:
9738 {
9739 GGML_ABORT("fatal error");
9740 }
9741 }
9742}
9743
9744// ggml_compute_forward_get_rel_pos
9745
9746static void ggml_compute_forward_get_rel_pos_f16(
9747 const ggml_compute_params * params,
9748 ggml_tensor * dst) {
9749 GGML_UNUSED(params);
9750
9751 const ggml_tensor * src0 = dst->src[0];
9752
9753 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
9754
9755 GGML_TENSOR_UNARY_OP_LOCALS
9756
9757 const int64_t w = ne1;
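    // src0 holds one embedding row per relative position; the row for offset i2 - i1
    // is at index (i2 - i1) + (w - 1), which is what pos computes below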
9758
9759 ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
9760 ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
9761
9762 for (int64_t i2 = 0; i2 < ne2; ++i2) {
9763 for (int64_t i1 = 0; i1 < ne1; ++i1) {
9764 const int64_t pos = (w - i1 - 1) + i2;
9765 for (int64_t i0 = 0; i0 < ne0; ++i0) {
9766 dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
9767 }
9768 }
9769 }
9770}
9771
9772void ggml_compute_forward_get_rel_pos(
9773 const ggml_compute_params * params,
9774 ggml_tensor * dst) {
9775
9776 const ggml_tensor * src0 = dst->src[0];
9777
9778 switch (src0->type) {
9779 case GGML_TYPE_F16:
9780 case GGML_TYPE_BF16:
9781 {
9782 ggml_compute_forward_get_rel_pos_f16(params, dst);
9783 } break;
9784 default:
9785 {
9786 GGML_ABORT("fatal error");
9787 }
9788 }
9789}
9790
9791// ggml_compute_forward_add_rel_pos
9792
9793static void ggml_compute_forward_add_rel_pos_f32(
9794 const ggml_compute_params * params,
9795 ggml_tensor * dst) {
9796
9797 const ggml_tensor * src0 = dst->src[0];
9798 const ggml_tensor * src1 = dst->src[1];
9799 const ggml_tensor * src2 = dst->src[2];
9800
9801 const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
9802 if (!inplace) {
9803 if (params->ith == 0) {
9804 memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
9805 }
9806 ggml_barrier(params->threadpool);
9807 }
9808 // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
9809
9810 float * src1_data = (float *) src1->data;
9811 float * src2_data = (float *) src2->data;
9812 float * dst_data = (float *) dst->data;
9813
9814 const int64_t ne10 = src1->ne[0];
9815 const int64_t ne11 = src1->ne[1];
9816 const int64_t ne12 = src1->ne[2];
9817 const int64_t ne13 = src1->ne[3];
9818
9819 const int ith = params->ith;
9820 const int nth = params->nth;
9821
9822 // total patches in dst
9823 const int np = ne13;
9824
9825 // patches per thread
9826 const int dp = (np + nth - 1)/nth;
9827
9828 // patch range for this thread
9829 const int ip0 = dp*ith;
9830 const int ip1 = MIN(ip0 + dp, np);
9831
9832 for (int64_t i13 = ip0; i13 < ip1; ++i13) {
9833 for (int64_t i12 = 0; i12 < ne12; ++i12) {
9834 for (int64_t i11 = 0; i11 < ne11; ++i11) {
9835 const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
9836 for (int64_t i10 = 0; i10 < ne10; ++i10) {
9837 const int64_t jp0 = jp1 + i10;
9838 const float src1_e = src1_data[jp0];
9839 const float src2_e = src2_data[jp0];
9840
9841 const int64_t jdh = jp0 * ne10;
9842 const int64_t jdw = jdh - (ne10 - 1) * i10;
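                    // jdh: start of a contiguous run of ne10 elements that each receive src2_e
                    // jdw: start of a column with stride ne10 whose elements each receive src1_e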
9843
9844 for (int64_t j = 0; j < ne10; ++j) {
9845 dst_data[jdh + j ] += src2_e;
9846 dst_data[jdw + j*ne10] += src1_e;
9847 }
9848 }
9849 }
9850 }
9851 }
9852}
9853
9854void ggml_compute_forward_add_rel_pos(
9855 const ggml_compute_params * params,
9856 ggml_tensor * dst) {
9857
9858 const ggml_tensor * src0 = dst->src[0];
9859
9860 switch (src0->type) {
9861 case GGML_TYPE_F32:
9862 {
9863 ggml_compute_forward_add_rel_pos_f32(params, dst);
9864 } break;
9865 default:
9866 {
9867 GGML_ABORT("fatal error");
9868 }
9869 }
9870}
9871
9872// ggml_compute_forward_rwkv_wkv6
9873
9874static void ggml_compute_forward_rwkv_wkv6_f32(
9875 const ggml_compute_params * params,
9876 ggml_tensor * dst) {
9877 const int64_t T = dst->src[1]->ne[2];
9878 const int64_t C = dst->ne[0];
9879 const int64_t HEADS = dst->src[1]->ne[1];
9880 const int64_t n_seqs = dst->src[5]->ne[1];
9881 const int64_t head_size = C / HEADS;
9882
9883 float * dst_data = (float *) dst->data;
9884 float * state = ((float *) dst->data) + C * T;
9885
9886 const int ith = params->ith;
9887 const int nth = params->nth;
9888
9889 if (ith >= HEADS) {
9890 return;
9891 }
9892
9893 const int h_start = (HEADS * ith) / nth;
9894 const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
9895 (HEADS * (ith + 1)) / nth : HEADS;
9896
9897 float * k = (float *) dst->src[0]->data;
9898 float * v = (float *) dst->src[1]->data;
9899 float * r = (float *) dst->src[2]->data;
9900 float * time_faaaa = (float *) dst->src[3]->data;
9901 float * time_decay = (float *) dst->src[4]->data;
9902
    size_t t_stride = HEADS * head_size; // Same as C
9904
9905 size_t h_stride = C / HEADS;
9906 GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
9907 size_t h_stride_2d = head_size * head_size;
9908
9909 if (ith == 0) {
9910 memset(dst_data, 0, T * C * sizeof(float));
9911 }
9912 ggml_barrier(params->threadpool);
9913
9914
9915 #if defined(__AVX__) && !defined(__AVX512F__)
9916 #define GGML_F32X GGML_F32x8
9917 #define GGML_F32X_SET1 GGML_F32x8_SET1
9918 #define GGML_F32X_LOAD GGML_F32x8_LOAD
9919 #define GGML_F32X_STORE GGML_F32x8_STORE
9920 #define GGML_F32X_MUL GGML_F32x8_MUL
9921 #define GGML_F32X_FMA GGML_F32x8_FMA
9922 #define WKV_VECTOR_SIZE 8
9923 #elif defined(__AVX512F__)
9924 #define GGML_F32X GGML_F32x16
9925 #define GGML_F32X_SET1 GGML_F32x16_SET1
9926 #define GGML_F32X_LOAD GGML_F32x16_LOAD
9927 #define GGML_F32X_STORE GGML_F32x16_STORE
9928 #define GGML_F32X_MUL GGML_F32x16_MUL
9929 #define GGML_F32X_FMA GGML_F32x16_FMA
9930 #define WKV_VECTOR_SIZE 16
9931 #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
9932 #define GGML_F32X GGML_F32xt
9933 #define GGML_F32X_SET1 GGML_F32xt_SET1
9934 #define GGML_F32X_LOAD GGML_F32xt_LOAD
9935 #define GGML_F32X_STORE GGML_F32xt_STORE
9936 #define GGML_F32X_MUL GGML_F32xt_MUL
9937 #define GGML_F32X_FMA GGML_F32xt_FMA
9938 #define WKV_VECTOR_SIZE 8
9939 #elif defined(__ARM_NEON) && defined(__aarch64__)
9940 #define GGML_F32X GGML_F32x4
9941 #define GGML_F32X_SET1 GGML_F32x4_SET1
9942 #define GGML_F32X_LOAD GGML_F32x4_LOAD
9943 #define GGML_F32X_STORE GGML_F32x4_STORE
9944 #define GGML_F32X_MUL GGML_F32x4_MUL
9945 #define GGML_F32X_FMA GGML_F32x4_FMA
9946 #define WKV_VECTOR_SIZE 4
9947 #endif
9948
9949 #ifdef WKV_VECTOR_SIZE
9950 int wkv_vector_size;
9951 #if defined(__ARM_FEATURE_SVE)
9952 wkv_vector_size = svcntw();
9953 #else
9954 wkv_vector_size = WKV_VECTOR_SIZE;
9955 #endif
9956 const int64_t vec_count = head_size / wkv_vector_size;
9957
9958 for (int64_t t = 0; t < T; t++) {
9959 size_t t_offset = t * t_stride;
9960 size_t state_offset = head_size * C * (t / (T / n_seqs));
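            // the first token of each sequence reads its previous state from src[5];
            // subsequent tokens reuse the state written in the previous iteration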
9961 float * state_cur = state + state_offset;
9962 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
9963
9964 for (int64_t h = h_start; h < h_end; h++) {
9965 size_t h_offset = h * h_stride;
9966 size_t t_h_offset = t_offset + h_offset;
9967 size_t h_2d_offset = h * h_stride_2d;
9968
9969 for (int64_t i = 0; i < head_size; i++) {
9970 size_t t_h_i_offset = t_h_offset + i;
9971 size_t h_i_offset = h_offset + i;
9972 size_t h_2d_i_offset = h_2d_offset + i * h_stride;
9973
9974 float k_val = k[t_h_i_offset];
9975 float r_val = r[t_h_i_offset];
9976 float time_faaaa_val = time_faaaa[h_i_offset];
9977 float time_decay_val = time_decay[t_h_i_offset];
9978
9979 // Broadcast scalar values to vectors
9980 GGML_F32X k_vec = GGML_F32X_SET1(k_val);
9981 GGML_F32X r_vec = GGML_F32X_SET1(r_val);
9982 GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
9983 GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
9984
9985 for (int64_t j = 0; j < vec_count; j++) {
9986 size_t base_j = j * wkv_vector_size;
9987 size_t t_h_j_offset = t_h_offset + base_j;
9988 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
9989
                        // Load wkv_vector_size elements at once
9991 GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
9992 GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
9993 GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
9994
9995 // Compute kv = v * k
9996 GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
9997
9998 // Compute temp = kv * time_faaaa + prev_state
9999 GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
10000
10001 // Update dst: dst += temp * r
10002 dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
10003 GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10004
10005 // Update state: state = prev_state * time_decay + kv
10006 GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
10007 GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
10008 }
10009
                    // Handle remaining elements (unused when head_size is a multiple of the vector size)
10011 for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
10012 size_t t_h_j_offset = t_h_offset + j;
10013 size_t h_2d_i_j_offset = h_2d_i_offset + j;
10014 float v_val = v[t_h_j_offset];
10015 float kv_val = v_val * k_val;
10016 float prev_state_val = state_prev[h_2d_i_j_offset];
10017 float temp_val = kv_val * time_faaaa_val + prev_state_val;
10018 dst_data[t_h_j_offset] += temp_val * r_val;
10019 state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10020 }
10021 }
10022 }
10023 }
10024
10025 #else
10026 // basically fused operations:
10027 // dst = r @ (time_faaaa * (k @ v) + state),
10028 // state = time_decay * state + (k @ v),
10029 // recursive through each token
10030 for (int64_t t = 0; t < T; t++) {
10031 size_t t_offset = t * t_stride;
10032 size_t state_offset = head_size * C * (t / (T / n_seqs));
10033 float * state_cur = state + state_offset;
10034 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
10035
10036 for (int64_t h = h_start; h < h_end; h++) {
10037 size_t h_offset = h * h_stride;
10038 size_t t_h_offset = t_offset + h_offset;
10039 size_t h_2d_offset = h * h_stride_2d;
10040
10041 for (int64_t i = 0; i < head_size; i++) {
10042 size_t t_h_i_offset = t_h_offset + i;
10043 size_t h_i_offset = h_offset + i;
10044 size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10045
10046 float k_val = k[t_h_i_offset];
10047 float r_val = r[t_h_i_offset];
10048 float time_faaaa_val = time_faaaa[h_i_offset];
10049 // RWKV v6: different time_decay for each token.
10050 float time_decay_val = time_decay[t_h_i_offset];
10051
10052 for (int64_t j = 0; j < head_size; j++) {
10053 size_t t_h_j_offset = t_h_offset + j;
10054 size_t h_2d_i_j_offset = h_2d_i_offset + j;
10055
10056 float v_val = v[t_h_j_offset];
10057 float kv_val = v_val * k_val;
10058 float prev_state_val = state_prev[h_2d_i_j_offset];
10059 float temp_val = kv_val * time_faaaa_val + prev_state_val;
10060 dst_data[t_h_j_offset] += temp_val * r_val;
10061 state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10062 }
10063 }
10064 }
10065 }
10066 #endif
10067}
10068
10069
10070void ggml_compute_forward_rwkv_wkv6(
10071 const ggml_compute_params * params,
10072 ggml_tensor * dst) {
10073
10074 const ggml_tensor * src0 = dst->src[0];
10075
10076 switch (src0->type) {
10077 case GGML_TYPE_F32:
10078 {
10079 ggml_compute_forward_rwkv_wkv6_f32(params, dst);
10080 } break;
10081 default:
10082 {
10083 GGML_ABORT("fatal error");
10084 }
10085 }
10086}
10087
10088// ggml_compute_forward_gla
10089
10090static void ggml_compute_forward_gla_f32(
10091 const ggml_compute_params * params,
10092 ggml_tensor * dst) {
10093 const int64_t T = dst->src[1]->ne[2];
10094 const int64_t C = dst->ne[0];
10095 const int64_t HEADS = dst->src[1]->ne[1];
10096 const int64_t n_seqs = dst->src[4]->ne[1];
10097 const int64_t head_size = C / HEADS;
10098 const float scale = ggml_get_op_params_f32(dst, 0);
10099
10100 float * dst_data = (float *) dst->data;
10101 float * state = ((float *) dst->data) + C * T;
10102
10103 const int ith = params->ith;
10104 const int nth = params->nth;
10105
10106 if (ith >= HEADS) {
10107 return;
10108 }
10109
10110 const int h_start = (HEADS * ith) / nth;
10111 const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10112 (HEADS * (ith + 1)) / nth : HEADS;
10113
10114 float * k = (float *) dst->src[0]->data;
10115 float * v = (float *) dst->src[1]->data;
10116 float * q = (float *) dst->src[2]->data;
10117 float * g = (float *) dst->src[3]->data;
10118
    size_t t_stride = HEADS * head_size; // Same as C
10120
10121 size_t h_stride = C / HEADS;
10122 GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10123 size_t h_stride_2d = head_size * head_size;
10124
10125 if (ith == 0) {
10126 memset(dst_data, 0, T * C * sizeof(float));
10127 }
10128 ggml_barrier(params->threadpool);
10129
10130
10131 #if defined(__AVX__) && !defined(__AVX512F__)
10132 #define GGML_F32X GGML_F32x8
10133 #define GGML_F32X_SET1 GGML_F32x8_SET1
10134 #define GGML_F32X_LOAD GGML_F32x8_LOAD
10135 #define GGML_F32X_STORE GGML_F32x8_STORE
10136 #define GGML_F32X_MUL GGML_F32x8_MUL
10137 #define GGML_F32X_FMA GGML_F32x8_FMA
10138 #define GLA_VECTOR_SIZE 8
10139 #elif defined(__AVX512F__)
10140 #define GGML_F32X GGML_F32x16
10141 #define GGML_F32X_SET1 GGML_F32x16_SET1
10142 #define GGML_F32X_LOAD GGML_F32x16_LOAD
10143 #define GGML_F32X_STORE GGML_F32x16_STORE
10144 #define GGML_F32X_MUL GGML_F32x16_MUL
10145 #define GGML_F32X_FMA GGML_F32x16_FMA
10146 #define GLA_VECTOR_SIZE 16
10147 #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
10148 #define GGML_F32X GGML_F32xt
10149 #define GGML_F32X_SET1 GGML_F32xt_SET1
10150 #define GGML_F32X_LOAD GGML_F32xt_LOAD
10151 #define GGML_F32X_STORE GGML_F32xt_STORE
10152 #define GGML_F32X_MUL GGML_F32xt_MUL
10153 #define GGML_F32X_FMA GGML_F32xt_FMA
10154 #define GLA_VECTOR_SIZE 8
10155 #elif defined(__ARM_NEON) && defined(__aarch64__)
10156 #define GGML_F32X GGML_F32x4
10157 #define GGML_F32X_SET1 GGML_F32x4_SET1
10158 #define GGML_F32X_LOAD GGML_F32x4_LOAD
10159 #define GGML_F32X_STORE GGML_F32x4_STORE
10160 #define GGML_F32X_MUL GGML_F32x4_MUL
10161 #define GGML_F32X_FMA GGML_F32x4_FMA
10162 #define GLA_VECTOR_SIZE 4
10163 #endif
10164
10165 #ifdef GLA_VECTOR_SIZE
10166 int gla_vector_size;
10167 #if defined(__ARM_FEATURE_SVE)
10168 gla_vector_size = svcntw();
10169 #else
10170 gla_vector_size = GLA_VECTOR_SIZE;
10171 #endif
10172 const int64_t vec_count = head_size / gla_vector_size;
10173
10174 for (int64_t t = 0; t < T; t++) {
10175 size_t t_offset = t * t_stride;
10176 size_t state_offset = head_size * C * (t / (T / n_seqs));
10177 float * state_cur = state + state_offset;
10178 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10179
10180 for (int64_t h = h_start; h < h_end; h++) {
10181 size_t h_offset = h * h_stride;
10182 size_t t_h_offset = t_offset + h_offset;
10183 size_t h_2d_offset = h * h_stride_2d;
10184
10185 for (int64_t i = 0; i < head_size; i++) {
10186 size_t t_h_i_offset = t_h_offset + i;
10187 size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10188
10189 float k_val = k[t_h_i_offset];
10190 float q_val = q[t_h_i_offset] * scale;
10191 float g_val = g[t_h_i_offset];
10192
10193 // Broadcast scalar values to vectors
10194 GGML_F32X k_vec = GGML_F32X_SET1(k_val);
10195 GGML_F32X q_vec = GGML_F32X_SET1(q_val);
10196 GGML_F32X g_vec = GGML_F32X_SET1(g_val);
10197
10198 for (int64_t j = 0; j < vec_count; j++) {
10199 size_t base_j = j * gla_vector_size;
10200 size_t t_h_j_offset = t_h_offset + base_j;
10201 size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
10202
                        // Load gla_vector_size elements at once
10204 GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
10205 GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
10206 GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
10207
10208 // Compute kv = v * k
10209 GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
10210
10211 // Compute temp = prev_state * g + kv
10212 GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
10213
10214 // Update dst: dst += temp * q
10215 dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
10216 GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10217
10218 // Update state
10219 GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
10220 }
10221
                    // Handle remaining elements (unused when head_size is a multiple of the vector size)
10223 for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
10224 size_t t_h_j_offset = t_h_offset + j;
10225 size_t h_2d_i_j_offset = h_2d_i_offset + j;
10226 float v_val = v[t_h_j_offset];
10227 float kv_val = v_val * k_val;
10228 float prev_state_val = state_prev[h_2d_i_j_offset];
10229 float temp_val = kv_val + prev_state_val * g_val;
10230 dst_data[t_h_j_offset] += temp_val * q_val;
10231 state_cur[h_2d_i_j_offset] = temp_val;
10232 }
10233 }
10234 }
10235 }
10236
10237 #else
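        // basically fused operations:
        // dst = (q * scale) @ (g * state + (k @ v)),
        // state = g * state + (k @ v),
        // recursive through each token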
10238 for (int64_t t = 0; t < T; t++) {
10239 size_t t_offset = t * t_stride;
10240 size_t state_offset = head_size * C * (t / (T / n_seqs));
10241 float * state_cur = state + state_offset;
10242 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10243
10244 for (int64_t h = h_start; h < h_end; h++) {
10245 size_t h_offset = h * h_stride;
10246 size_t t_h_offset = t_offset + h_offset;
10247 size_t h_2d_offset = h * h_stride_2d;
10248
10249 for (int64_t i = 0; i < head_size; i++) {
10250 size_t t_h_i_offset = t_h_offset + i;
10251 size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10252
10253 float k_val = k[t_h_i_offset];
10254 float q_val = q[t_h_i_offset] * scale;
10255 float g_val = g[t_h_i_offset];
10256
10257 for (int64_t j = 0; j < head_size; j++) {
10258 size_t t_h_j_offset = t_h_offset + j;
10259 size_t h_2d_i_j_offset = h_2d_i_offset + j;
10260
10261 float v_val = v[t_h_j_offset];
10262 float kv_val = v_val * k_val;
10263 float prev_state_val = state_prev[h_2d_i_j_offset];
10264 float temp_val = prev_state_val * g_val + kv_val;
10265 dst_data[t_h_j_offset] += temp_val * q_val;
10266 state_cur[h_2d_i_j_offset] = temp_val;
10267 }
10268 }
10269 }
10270 }
10271 #endif
10272}
10273
10274
10275void ggml_compute_forward_gla(
10276 const ggml_compute_params * params,
10277 ggml_tensor * dst) {
10278
10279 const ggml_tensor * src0 = dst->src[0];
10280
10281 switch (src0->type) {
10282 case GGML_TYPE_F32:
10283 {
10284 ggml_compute_forward_gla_f32(params, dst);
10285 } break;
10286 default:
10287 {
10288 GGML_ABORT("fatal error");
10289 }
10290 }
10291}
10292
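// ggml_compute_forward_solve_tri
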
10293static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10294 const struct ggml_tensor * src0 = dst->src[0]; // A (lower triangular)
10295 const struct ggml_tensor * src1 = dst->src[1]; // B (RHS)
10296
10297 GGML_TENSOR_BINARY_OP_LOCALS;
10298
10299 GGML_ASSERT(src0->type == GGML_TYPE_F32);
10300 GGML_ASSERT(src1->type == GGML_TYPE_F32);
10301 GGML_ASSERT(dst->type == GGML_TYPE_F32);
10302
10303 GGML_ASSERT(ne00 == ne01); // A must be square
10304 GGML_ASSERT(ne0 == ne10); // solution cols == B cols
10305 GGML_ASSERT(ne1 == ne11); // solution rows == B rows
10306
10307 GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
10308 GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
10309
10310 const int ith = params->ith;
10311 const int nth = params->nth;
10312
10313 const int64_t k = ne10; // number of RHS columns
    const int64_t n = ne11; // A is n×n
10315 const int64_t nr = ne02 * ne03 * k; // we're parallelizing on columns here, so seq x token x column will be the unit
10316
10317 // chunks per thread
10318 const int64_t dr = (nr + nth - 1)/nth;
10319
10320 // chunk range for this thread
10321 const int64_t ir0 = dr*ith;
10322 const int64_t ir1 = MIN(ir0 + dr, nr);
10323
10324 const float * A = (const float *) src0->data; // [n, n, B1, B2]
10325 const float * B = (const float *) src1->data; // [n, k, B1, B2]
10326 float * X = ( float *) dst->data; // [n, k, B1, B2]
10327
10328 for (int64_t ir = ir0; ir < ir1; ++ir) {
10329 const int64_t i03 = ir/(ne02*k);
10330 const int64_t i02 = (ir - i03*ne02*k)/k;
10331 const int64_t i01 = (ir - i03*ne02*k - i02*k);
10332
10333 const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
10334 const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
10335
10336 float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
10337
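        // forward substitution down column i01 of this batch:
        //   X[i00][i01] = (B[i00][i01] - sum_{t < i00} A[i00][t] * X[t][i01]) / A[i00][i00]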
10338 for (int64_t i00 = 0; i00 < n; ++i00) {
10339 float sum = 0.0f;
10340 for (int64_t t = 0; t < i00; ++t) {
10341 sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
10342 }
10343
10344 const float diag = A_batch[i00 * n + i00];
10345 assert(diag != 0.0f && "Zero diagonal in triangular matrix");
10346
10347 X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
10348 }
10349 }
10350}
10351
10352void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10353 const ggml_tensor * src0 = dst->src[0];
10354 const ggml_tensor * src1 = dst->src[1];
10355
10356 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
10357 ggml_compute_forward_solve_tri_f32(params, dst);
10358 } else {
10359 GGML_ABORT("fatal error");
10360 }
10361}
10362
10363// ggml_compute_forward_rwkv_wkv7
10364
10365static void ggml_compute_forward_rwkv_wkv7_f32(
10366 const ggml_compute_params * params,
10367 ggml_tensor * dst) {
10368 const int64_t T = dst->src[1]->ne[2];
10369 const int64_t C = dst->ne[0];
10370 const int64_t HEADS = dst->src[1]->ne[1];
10371 const int64_t n_seqs = dst->src[6]->ne[1];
10372 const int64_t head_size = C / HEADS;
10373
10374 float * dst_data = (float *) dst->data;
10375 float * state = ((float *) dst->data) + C * T;
10376
10377 const int ith = params->ith;
10378 const int nth = params->nth;
10379
10380 if (ith >= HEADS) {
10381 return;
10382 }
10383
10384 const int h_start = (HEADS * ith) / nth;
10385 const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10386 (HEADS * (ith + 1)) / nth : HEADS;
10387
10388 float * r = (float *) dst->src[0]->data;
10389 float * w = (float *) dst->src[1]->data;
10390 float * k = (float *) dst->src[2]->data;
10391 float * v = (float *) dst->src[3]->data;
10392 float * a = (float *) dst->src[4]->data;
10393 float * b = (float *) dst->src[5]->data;
10394
    int64_t t_stride = HEADS * head_size; // Same as C
10396
10397 int64_t h_stride = C / HEADS;
10398 GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10399 int64_t h_stride_2d = head_size * head_size;
10400
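    // fused operations per token, for each row i of the per-head state:
    //   sa       = dot(a, state_prev[i])
    //   state[i] = state_prev[i] * w + v[i] * k + sa * b
    //   dst[i]   = dot(state[i], r)
    // recursive through each token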
10401 #if defined(GGML_SIMD)
10402 #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
        // route to the scalar implementation for now
        // TODO: write SVE and RVV implementations
10404 for (int64_t t = 0; t < T; t++) {
10405 int64_t t_offset = t * t_stride;
10406 int64_t state_offset = head_size * C * (t / (T / n_seqs));
10407 float * state_cur = state + state_offset;
10408 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10409
10410 for (int64_t h = h_start; h < h_end; h++) {
10411 int64_t h_offset = h * h_stride;
10412 int64_t t_h_offset = t_offset + h_offset;
10413 int64_t h_2d_offset = h * h_stride_2d;
10414
10415 for (int64_t i = 0; i < head_size; i++) {
10416 int64_t t_h_i_offset = t_h_offset + i;
10417 int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10418
10419 float v_val = v[t_h_i_offset];
10420
10421 float sa = 0, result = 0;
10422 for (int64_t j = 0; j < head_size; j++) {
10423 sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10424 }
10425
10426 for (int64_t j = 0; j < head_size; j++) {
10427 int64_t t_h_j_offset = t_h_offset + j;
10428 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10429
10430 float r_val = r[t_h_j_offset];
10431 float w_val = w[t_h_j_offset];
10432 float k_val = k[t_h_j_offset];
10433 float b_val = b[t_h_j_offset];
10434 float kv_val = v_val * k_val;
10435 float prev_state_val = state_prev[h_2d_i_j_offset];
10436 state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10437 result += state_cur[h_2d_i_j_offset] * r_val;
10438 }
10439 dst_data[t_h_i_offset] = result;
10440 }
10441 }
10442 }
10443 #else
10444 for (int64_t t = 0; t < T; t++) {
10445 int64_t t_offset = t * t_stride;
10446 int64_t state_offset = head_size * C * (t / (T / n_seqs));
10447 float * state_cur = state + state_offset;
10448 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10449
10450 for (int64_t h = h_start; h < h_end; h++) {
10451 int64_t h_offset = h * h_stride;
10452 int64_t t_h_offset = t_offset + h_offset;
10453 int64_t h_2d_offset = h * h_stride_2d;
10454
10455 for (int64_t ii = 0; ii < head_size; ii++) {
10456 int64_t t_h_i_offset = t_h_offset + ii;
10457 int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
10458
10459 GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
10460
10461 float sa = 0;
10462 {
10463 GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10464 GGML_F32_VEC ax[GGML_F32_ARR];
10465 GGML_F32_VEC ay[GGML_F32_ARR];
10466 for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
10467 for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10468 ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
10469 ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
10470 sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
10471 }
10472 }
10473 GGML_F32_VEC_REDUCE(sa, sum);
10474 }
10475
10476 GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
10477
10478 int64_t j = 0;
10479 GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10480 for (; j < head_size; j += GGML_F32_STEP) {
10481 for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10482 int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
10483 int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
10484
10485 GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
10486 GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
10487 GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
10488 GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
10489
10490 k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
10491
10492 GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
10493 // kv + s * decay + sa * b
10494 state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
10495 state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
10496 GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
10497
10498 result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
10499 }
10500 }
10501 GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
10502
                    // handle any left-overs (there shouldn't be any when head_size is a multiple of GGML_F32_STEP)
10504 for (; j < head_size; j++) {
10505 int64_t t_h_j_offset = t_h_offset + j;
10506 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10507
10508 float r_val = r[t_h_j_offset];
10509 float w_val = w[t_h_j_offset];
10510 float k_val = k[t_h_j_offset];
10511 float b_val = b[t_h_j_offset];
10512 float kv_val = v[t_h_i_offset] * k_val;
10513
10514 float prev_state_val = state_prev[h_2d_i_j_offset];
10515 state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10516 dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
10517 }
10518 }
10519 }
10520 }
10521 #endif
10522 #else
10523 for (int64_t t = 0; t < T; t++) {
10524 int64_t t_offset = t * t_stride;
10525 int64_t state_offset = head_size * C * (t / (T / n_seqs));
10526 float * state_cur = state + state_offset;
10527 float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10528
10529 for (int64_t h = h_start; h < h_end; h++) {
10530 int64_t h_offset = h * h_stride;
10531 int64_t t_h_offset = t_offset + h_offset;
10532 int64_t h_2d_offset = h * h_stride_2d;
10533
10534 for (int64_t i = 0; i < head_size; i++) {
10535 int64_t t_h_i_offset = t_h_offset + i;
10536 int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10537
10538 float v_val = v[t_h_i_offset];
10539
10540 float sa = 0, result = 0;
10541 for (int64_t j = 0; j < head_size; j++) {
10542 sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10543 }
10544
10545 for (int64_t j = 0; j < head_size; j++) {
10546 int64_t t_h_j_offset = t_h_offset + j;
10547 int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10548
10549 float r_val = r[t_h_j_offset];
10550 float w_val = w[t_h_j_offset];
10551 float k_val = k[t_h_j_offset];
10552 float b_val = b[t_h_j_offset];
10553 float kv_val = v_val * k_val;
10554 float prev_state_val = state_prev[h_2d_i_j_offset];
10555 state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10556 result += state_cur[h_2d_i_j_offset] * r_val;
10557 }
10558 dst_data[t_h_i_offset] = result;
10559 }
10560 }
10561 }
10562 #endif
10563}
10564
10565
10566void ggml_compute_forward_rwkv_wkv7(
10567 const ggml_compute_params * params,
10568 ggml_tensor * dst) {
10569
10570 const ggml_tensor * src0 = dst->src[0];
10571
10572 switch (src0->type) {
10573 case GGML_TYPE_F32:
10574 {
10575 ggml_compute_forward_rwkv_wkv7_f32(params, dst);
10576 } break;
10577 default:
10578 {
10579 GGML_ABORT("fatal error");
10580 }
10581 }
10582}
10583
10584// ggml_compute_forward_map_custom1
10585
10586void ggml_compute_forward_map_custom1(
10587 const ggml_compute_params * params,
10588 ggml_tensor * dst) {
10589
10590 const ggml_tensor * a = dst->src[0];
10591
10592 struct ggml_map_custom1_op_params p;
10593 memcpy(&p, dst->op_params, sizeof(p));
10594
10595 p.fun(dst, a, params->ith, params->nth, p.userdata);
10596}
10597
10598// ggml_compute_forward_map_custom2
10599
10600void ggml_compute_forward_map_custom2(
10601 const ggml_compute_params * params,
10602 ggml_tensor * dst) {
10603
10604 const ggml_tensor * a = dst->src[0];
10605 const ggml_tensor * b = dst->src[1];
10606
10607 struct ggml_map_custom2_op_params p;
10608 memcpy(&p, dst->op_params, sizeof(p));
10609
10610 p.fun(dst, a, b, params->ith, params->nth, p.userdata);
10611}
10612
10613// ggml_compute_forward_map_custom3
10614
10615void ggml_compute_forward_map_custom3(
10616 const ggml_compute_params * params,
10617 ggml_tensor * dst) {
10618
10619 const ggml_tensor * a = dst->src[0];
10620 const ggml_tensor * b = dst->src[1];
10621 const ggml_tensor * c = dst->src[2];
10622
10623 struct ggml_map_custom3_op_params p;
10624 memcpy(&p, dst->op_params, sizeof(p));
10625
10626 p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
10627}
10628
10629// ggml_compute_forward_custom
10630
10631void ggml_compute_forward_custom(
10632 const struct ggml_compute_params * params,
10633 struct ggml_tensor * dst) {
10634
10635 struct ggml_custom_op_params p;
10636 memcpy(&p, dst->op_params, sizeof(p));
10637
10638 p.fun(dst, params->ith, params->nth, p.userdata);
10639}
10640
10641// ggml_compute_forward_cross_entropy_loss
10642
10643static void ggml_compute_forward_cross_entropy_loss_f32(
10644 const ggml_compute_params * params,
10645 ggml_tensor * dst) {
10646
10647 const ggml_tensor * src0 = dst->src[0];
10648 const ggml_tensor * src1 = dst->src[1];
10649
10650 GGML_ASSERT(src0->type == GGML_TYPE_F32);
10651 GGML_ASSERT(src1->type == GGML_TYPE_F32);
10652 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
10653 GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
10654 GGML_ASSERT(ggml_are_same_shape(src0, src1));
10655 GGML_ASSERT(ggml_is_scalar(dst));
10656 GGML_ASSERT(dst->type == GGML_TYPE_F32);
10657
10658 // TODO: handle transposed/permuted matrices
10659 const int64_t nc = src0->ne[0];
10660 const int64_t nr = ggml_nrows(src0);
10661
10662 const int ith = params->ith;
10663 const int nth = params->nth;
10664
10665 float * sums = (float *) params->wdata;
10666 float * st = ((float *) params->wdata) + nth + ith*nc;
10667 float sum_thread = 0.0f;
10668
10669 GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
10670
10671 // rows per thread
10672 const int64_t dr = (nr + nth - 1)/nth;
10673
10674 // row range for this thread
10675 const int64_t ir0 = dr*ith;
10676 const int64_t ir1 = MIN(ir0 + dr, nr);
10677
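    // each thread accumulates the sum over its rows of dot(src1, log_softmax(src0));
    // thread 0 combines the partial sums and scales by -1/nr below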
10678 for (int64_t i1 = ir0; i1 < ir1; ++i1) {
10679 const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
10680 const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
10681
10682#ifndef NDEBUG
10683 for (int64_t i = 0; i < nc; ++i) {
10684 //printf("p[%d] = %f\n", i, p[i]);
10685 assert(!isnan(s0[i]));
10686 assert(!isnan(s1[i]));
10687 }
10688#endif
10689
10690 float max = -INFINITY;
10691 ggml_vec_max_f32(nc, &max, s0);
10692 const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
10693 assert(sum_softmax >= 0.0);
10694
10695 ggml_vec_add1_f32(nc, st, st, -sum_softmax);
10696 ggml_vec_mul_f32(nc, st, st, s1);
10697
10698 float sum_st = 0.0f;
10699 ggml_vec_sum_f32(nc, &sum_st, st);
10700 sum_thread += sum_st;
10701
10702#ifndef NDEBUG
10703 for (int64_t i = 0; i < nc; ++i) {
10704 assert(!isnan(st[i]));
10705 assert(!isinf(st[i]));
10706 }
10707#endif
10708 }
10709 sums[ith] = sum_thread;
10710 ggml_barrier(params->threadpool);
10711
10712 if (ith == 0) {
10713 float * dp = (float *) dst->data;
10714 ggml_vec_sum_f32(nth, dp, sums);
10715 dp[0] *= -1.0f / (float) nr;
10716 }
10717}
10718
10719void ggml_compute_forward_cross_entropy_loss(
10720 const ggml_compute_params * params,
10721 ggml_tensor * dst) {
10722
10723 const ggml_tensor * src0 = dst->src[0];
10724
10725 switch (src0->type) {
10726 case GGML_TYPE_F32:
10727 {
10728 ggml_compute_forward_cross_entropy_loss_f32(params, dst);
10729 } break;
10730 default:
10731 {
10732 GGML_ABORT("fatal error");
10733 }
10734 }
10735}
10736
10737// ggml_compute_forward_cross_entropy_loss_back
10738
10739static void ggml_compute_forward_cross_entropy_loss_back_f32(
10740 const ggml_compute_params * params,
10741 ggml_tensor * dst) {
10742
10743 const ggml_tensor * grad = dst->src[0]; // gradient of forward pass output
10744 const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
10745 const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
10746
10747 GGML_ASSERT(ggml_is_contiguous(dst));
10748 GGML_ASSERT(ggml_is_contiguous(src0f));
10749 GGML_ASSERT(ggml_is_contiguous(src1f));
10750 GGML_ASSERT(ggml_is_contiguous(grad));
10751 GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
10752
10753 const int64_t ith = params->ith;
10754 const int64_t nth = params->nth;
10755
10756 // TODO: handle transposed/permuted matrices
10757 const int64_t nc = src0f->ne[0];
10758 const int64_t nr = ggml_nrows(src0f);
10759
10760 // rows per thread
10761 const int64_t dr = (nr + nth - 1)/nth;
10762
10763 // row range for this thread
10764 const int64_t ir0 = dr*ith;
10765 const int64_t ir1 = MIN(ir0 + dr, nr);
10766
10767 const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
10768
10769 for (int64_t i1 = ir0; i1 < ir1; i1++) {
10770 float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
10771 const float * s0 = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
10772 const float * s1 = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
10773
10774#ifndef NDEBUG
10775 for (int64_t i = 0; i < nc; ++i) {
10776 //printf("p[%d] = %f\n", i, p[i]);
10777 assert(!isnan(s0[i]));
10778 assert(!isnan(s1[i]));
10779 }
10780#endif
10781
10782 // soft_max
10783 float max = -INFINITY;
10784 ggml_vec_max_f32(nc, &max, s0);
10785 const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
10786 assert(sum > 0.0);
10787 ggml_vec_scale_f32(nc, ds0, 1.0/sum);
10788
10789 // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
10790 ggml_vec_sub_f32(nc, ds0, ds0, s1);
10791 ggml_vec_scale_f32(nc, ds0, d_by_nr);
10792
10793#ifndef NDEBUG
10794 for (int64_t i = 0; i < nc; ++i) {
10795 assert(!isnan(ds0[i]));
10796 assert(!isinf(ds0[i]));
10797 }
10798#endif
10799 }
10800}
10801
10802void ggml_compute_forward_cross_entropy_loss_back(
10803 const ggml_compute_params * params,
10804 ggml_tensor * dst) {
10805
10806 const ggml_tensor * src0 = dst->src[0];
10807
10808 switch (src0->type) {
10809 case GGML_TYPE_F32:
10810 {
10811 ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
10812 } break;
10813 default:
10814 {
10815 GGML_ABORT("fatal error");
10816 }
10817 }
10818}
10819
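// ggml_compute_forward_opt_step_adamw
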
10820static void ggml_compute_forward_opt_step_adamw_f32(
10821 const ggml_compute_params * params,
10822 ggml_tensor * dst) {
10823
10824 const ggml_tensor * src0 = dst->src[0];
10825 const ggml_tensor * src0_grad = dst->src[1];
10826 const ggml_tensor * src0_grad_m = dst->src[2];
10827 const ggml_tensor * src0_grad_v = dst->src[3];
10828 const ggml_tensor * adamw_params = dst->src[4];
10829
10830 GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10831 GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
10832 GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
10833 GGML_ASSERT(ggml_nelements(adamw_params) == 7);
10834
10835 const int ith = params->ith;
10836 const int nth = params->nth;
10837
10838 const int nr = ggml_nrows(src0);
10839
10840 GGML_TENSOR_UNARY_OP_LOCALS
10841 GGML_ASSERT(nb00 == sizeof(float));
10842
10843 // rows per thread
10844 const int dr = (nr + nth - 1)/nth;
10845
10846 // row range for this thread
10847 const int ir0 = dr*ith;
10848 const int ir1 = MIN(ir0 + dr, nr);
10849
10850 const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
10851
10852 const float alpha = adamw_params_ptr[0];
10853 const float beta1 = adamw_params_ptr[1];
10854 const float beta2 = adamw_params_ptr[2];
10855 const float eps = adamw_params_ptr[3];
10856 const float wd = adamw_params_ptr[4];
10857 const float beta1h = adamw_params_ptr[5];
10858 const float beta2h = adamw_params_ptr[6];
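    // decoupled weight decay: scale the weights by (1 - alpha*wd) before the Adam update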
10859 const float keep = 1.f - alpha * wd;
10860 for (int ir = ir0; ir < ir1; ++ir) {
10861 const int64_t i03 = ir/(ne02*ne01);
10862 const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10863 const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10864
10865 const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
10866
10867 float * w = (float *) ((char *) src0->data + offset); // weight
10868 const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
10869 float * m = (float *) ((char *) src0_grad_m->data + offset);
10870 float * v = (float *) ((char *) src0_grad_v->data + offset);
10871
10872 for (int i00 = 0; i00 < ne00; ++i00) {
10873 m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
10874 v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
10875
10876 const float mh = m[i00]*beta1h;
10877 const float vh = sqrtf(v[i00]*beta2h) + eps;
10878
10879 // The weight decay is applied independently of the Adam momenta m and v.
10880 // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
10881 // See: https://arxiv.org/pdf/1711.05101v3.pdf
10882 w[i00] = w[i00] * keep - alpha * mh / vh;
10883 }
10884 }
10885}
10886
10887void ggml_compute_forward_opt_step_adamw(
10888 const ggml_compute_params * params,
10889 ggml_tensor * dst) {
10890
10891 const ggml_tensor * src0 = dst->src[0];
10892
10893 switch (src0->type) {
10894 case GGML_TYPE_F32:
10895 {
10896 ggml_compute_forward_opt_step_adamw_f32(params, dst);
10897 } break;
10898 default:
10899 {
10900 GGML_ABORT("fatal error");
10901 }
10902 }
10903}
10904
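// ggml_compute_forward_opt_step_sgd
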
10905static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
10906 const ggml_tensor * src0 = dst->src[0];
10907 const ggml_tensor * src0_grad = dst->src[1];
10908 const ggml_tensor * sgd_params = dst->src[2];
10909
10910 GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10911 GGML_ASSERT(ggml_nelements(sgd_params) == 2);
10912
10913 const int ith = params->ith;
10914 const int nth = params->nth;
10915
10916 const int nr = ggml_nrows(src0);
10917
10918 GGML_TENSOR_UNARY_OP_LOCALS
10919 GGML_ASSERT(nb00 == sizeof(float));
10920
10921 // rows per thread
10922 const int dr = (nr + nth - 1) / nth;
10923
10924 // row range for this thread
10925 const int ir0 = dr * ith;
10926 const int ir1 = MIN(ir0 + dr, nr);
10927
    // uses the subset of the adamw params that SGD needs - alpha, wd - could have a separate struct
10929 const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
10930 const float alpha = sgd_params_ptr[0];
10931 const float keep = 1.f - alpha * sgd_params_ptr[1];
10932
10933 for (int ir = ir0; ir < ir1; ++ir) {
10934 const int64_t i03 = ir / (ne02 * ne01);
10935 const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10936 const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10937
10938 const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10939
10940 float * w = (float *) ((char *) src0->data + offset); // weight
10941 const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
10942
10943 for (int i00 = 0; i00 < ne00; ++i00) {
10944 w[i00] = w[i00] * keep - alpha * g[i00];
10945 }
10946 }
10947}
10948
10949void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
10950 const ggml_tensor * src0 = dst->src[0];
10951
10952 switch (src0->type) {
10953 case GGML_TYPE_F32:
10954 {
10955 ggml_compute_forward_opt_step_sgd_f32(params, dst);
10956 }
10957 break;
10958 default:
10959 {
10960 GGML_ABORT("fatal error - sgd is F32 only");
10961 }
10962 }
10963}