1#include "ops.h"
    2
    3#include "ggml-cpu.h"
    4#include "ggml-impl.h"
    5#include "binary-ops.h"
    6#include "ggml.h"
    7#include "unary-ops.h"
    8#include "vec.h"
    9
   10#include <algorithm>
   11#include <cfloat>
   12#include <cmath>
   13
   14// ggml_compute_forward_dup
   15
   16static void ggml_compute_forward_dup_same_cont(
   17        const ggml_compute_params * params,
   18        ggml_tensor * dst) {
   19
   20    const ggml_tensor * src0 = dst->src[0];
   21
   22    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
   23    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
   24    GGML_ASSERT(src0->type == dst->type);
   25
   26    const size_t nb0 = ggml_type_size(src0->type);
   27
   28    const int ith = params->ith; // thread index
   29    const int nth = params->nth; // number of threads
   30
   31    // parallelize by blocks
   32    const int nk = ggml_nelements(src0)/ggml_blck_size(src0->type);
   33    const int dr = (nk + nth - 1) / nth;
   34    const int k0 = dr * ith;
   35    const int k1 = MIN(k0 + dr, nk);
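    // e.g. nk = 10 blocks with nth = 4 threads gives dr = 3, so the threads get the
    // ranges [0,3), [3,6), [6,9), [9,10); the MIN above clamps the last range and the
    // k0 < k1 check below lets surplus threads skip the copy entirely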
   36
   37    if (k0 < k1) {
   38        memcpy(
   39            ((char *)  dst->data + k0*nb0),
   40            ((char *) src0->data + k0*nb0),
   41            (k1 - k0) * nb0);
   42    }
   43}
   44
   45template<typename src_t, typename dst_t>
   46static void ggml_compute_forward_dup_flt(
   47        const ggml_compute_params * params,
   48        ggml_tensor * dst) {
   49
   50    const ggml_tensor * src0 = dst->src[0];
   51
   52    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
   53    GGML_ASSERT(!ggml_is_quantized(src0->type) && !ggml_is_quantized(dst->type));
   54
   55    GGML_TENSOR_UNARY_OP_LOCALS
   56
   57    const int ith = params->ith; // thread index
   58    const int nth = params->nth; // number of threads
   59
   60    // parallelize by rows
   61    const int nr = ne01;
   62    // number of rows per thread
   63    const int dr = (nr + nth - 1) / nth;
   64    // row range for this thread
   65    const int ir0 = dr * ith;
   66    const int ir1 = MIN(ir0 + dr, nr);
   67
   68    // case: type & row size equal
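    // when src0 and dst share the same element type, the same row length and tightly
    // packed rows (unit element stride), each row can be copied with a single memcpy
    // even if the higher-dimension strides of src0 and dst differ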
   69    if (src0->type == dst->type &&
   70        ne00 == ne0 &&
   71        nb00 == ggml_type_size(src0->type) && nb0 == ggml_type_size(dst->type)) {
   72        // copy by rows
   73        const size_t rs = ne00*nb00;
   74        for (int64_t i03 = 0; i03 < ne03; i03++) {
   75            for (int64_t i02 = 0; i02 < ne02; i02++) {
   76                for (int64_t i01 = ir0; i01 < ir1; i01++) {
   77                    memcpy(
   78                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
   79                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
   80                        rs);
   81                }
   82            }
   83        }
   84        return;
   85    }
   86
   87    // case: dst tensor is contiguous
   88    if (ggml_is_contiguous(dst)) {
   89        if (nb00 == sizeof(src_t)) {
   90            if constexpr (std::is_same_v<dst_t, src_t>) {
   91                // same type
   92                size_t id = 0;
   93                const size_t rs = ne00 * nb00;
   94                char * dst_ptr = (char *) dst->data;
   95
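                // id is the running byte offset into the contiguous dst; each thread skips
                // over the rows owned by the other threads (ir0 rows before its range and
                // ne01 - ir1 rows after it) so that the per-thread writes never overlap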
   96                for (int i03 = 0; i03 < ne03; i03++) {
   97                    for (int i02 = 0; i02 < ne02; i02++) {
   98                        id += rs * ir0;
   99                        for (int i01 = ir0; i01 < ir1; i01++) {
  100                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  101                            memcpy(dst_ptr + id, src0_ptr, rs);
  102                            id += rs;
  103                        }
  104                        id += rs * (ne01 - ir1);
  105                    }
  106                }
  107            } else {
  108                // casting between non-quantized types
  109                size_t id = 0;
  110                dst_t * dst_ptr = (dst_t *) dst->data;
  111
  112                for (int i03 = 0; i03 < ne03; i03++) {
  113                    for (int i02 = 0; i02 < ne02; i02++) {
  114                        id += ne00 * ir0;
  115                        for (int i01 = ir0; i01 < ir1; i01++) {
  116                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  117                            for (int i00 = 0; i00 < ne00; i00++) {
  118                                float tmp = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
  119                                dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
  120                                id++;
  121                            }
  122                        }
  123                        id += ne00 * (ne01 - ir1);
  124                    }
  125                }
  126            }
  127        } else {
  128            //printf("%s: this is not optimal - fix me\n", __func__);
  129
  130            size_t id = 0;
  131            dst_t * dst_ptr = (dst_t *) dst->data;
  132
  133            for (int i03 = 0; i03 < ne03; i03++) {
  134                for (int i02 = 0; i02 < ne02; i02++) {
  135                    id += ne00 * ir0;
  136                    for (int i01 = ir0; i01 < ir1; i01++) {
  137                        for (int i00 = 0; i00 < ne00; i00++) {
  138                            const src_t * src0_ptr = (src_t *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  139
  140                            float tmp = type_conversion_table<src_t>::to_f32(*src0_ptr);
  141                            dst_ptr[id] = type_conversion_table<dst_t>::from_f32(tmp);
  142                            id++;
  143                        }
  144                    }
  145                    id += ne00 * (ne01 - ir1);
  146                }
  147            }
  148        }
  149        return;
  150    }
  151
  152    // dst counters
  153    int64_t i10 = 0;
  154    int64_t i11 = 0;
  155    int64_t i12 = 0;
  156    int64_t i13 = 0;
  157
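    // general case: src0 and dst may have different shapes (only the element counts match),
    // so the dst position (i10,i11,i12,i13) is advanced element by element and wrapped
    // manually; before and after each thread's row range the counters are fast-forwarded
    // past the elements handled by the other threads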
  158    if constexpr (std::is_same_v<dst_t, src_t>) {
  159        for (int64_t i03 = 0; i03 < ne03; i03++) {
  160            for (int64_t i02 = 0; i02 < ne02; i02++) {
  161                i10 += ne00 * ir0;
  162                while (i10 >= ne0) {
  163                    i10 -= ne0;
  164                    if (++i11 == ne1) {
  165                        i11 = 0;
  166                        if (++i12 == ne2) {
  167                            i12 = 0;
  168                            if (++i13 == ne3) {
  169                                i13 = 0;
  170                            }
  171                        }
  172                    }
  173                }
  174                for (int64_t i01 = ir0; i01 < ir1; i01++) {
  175                    for (int64_t i00 = 0; i00 < ne00; i00++) {
  176                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  177                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
  178
  179                        memcpy(dst_ptr, src0_ptr, sizeof(dst_t));
  180
                        if (++i10 == ne0) {
                            i10 = 0;
                            if (++i11 == ne1) {
                                i11 = 0;
                                if (++i12 == ne2) {
                                    i12 = 0;
                                    if (++i13 == ne3) {
  188                                        i13 = 0;
  189                                    }
  190                                }
  191                            }
  192                        }
  193                    }
  194                }
  195                i10 += ne00 * (ne01 - ir1);
  196                while (i10 >= ne0) {
  197                    i10 -= ne0;
  198                    if (++i11 == ne1) {
  199                        i11 = 0;
  200                        if (++i12 == ne2) {
  201                            i12 = 0;
  202                            if (++i13 == ne3) {
  203                                i13 = 0;
  204                            }
  205                        }
  206                    }
  207                }
  208            }
  209        }
  210
  211    } else {
  212        for (int64_t i03 = 0; i03 < ne03; i03++) {
  213            for (int64_t i02 = 0; i02 < ne02; i02++) {
  214                i10 += ne00 * ir0;
  215                while (i10 >= ne0) {
  216                    i10 -= ne0;
  217                    if (++i11 == ne1) {
  218                        i11 = 0;
  219                        if (++i12 == ne2) {
  220                            i12 = 0;
  221                            if (++i13 == ne3) {
  222                                i13 = 0;
  223                            }
  224                        }
  225                    }
  226                }
  227                for (int64_t i01 = ir0; i01 < ir1; i01++) {
  228                    for (int64_t i00 = 0; i00 < ne00; i00++) {
  229                        const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  230                              char * dst_ptr  = ((char *)  dst->data + i10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
  231
  232                        float tmp = type_conversion_table<src_t>::to_f32(*(const src_t *) src0_ptr);
  233                        *(dst_t *) dst_ptr = type_conversion_table<dst_t>::from_f32(tmp);
  234
  235                        if (++i10 == ne0) {
  236                            i10 = 0;
  237                            if (++i11 == ne1) {
  238                                i11 = 0;
  239                                if (++i12 == ne2) {
  240                                    i12 = 0;
  241                                    if (++i13 == ne3) {
  242                                        i13 = 0;
  243                                    }
  244                                }
  245                            }
  246                        }
  247                    }
  248                }
  249                i10 += ne00 * (ne01 - ir1);
  250                while (i10 >= ne0) {
  251                    i10 -= ne0;
  252                    if (++i11 == ne1) {
  253                        i11 = 0;
  254                        if (++i12 == ne2) {
  255                            i12 = 0;
  256                            if (++i13 == ne3) {
  257                                i13 = 0;
  258                            }
  259                        }
  260                    }
  261                }
  262            }
  263        }
  264    }
  265}
  266
  267
  268template<typename src_t>
  269static void ggml_compute_forward_dup_to_q(
  270        const ggml_compute_params * params,
  271        ggml_tensor * dst) {
  272
  273    const ggml_tensor * src0 = dst->src[0];
  274
  275    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  276    GGML_ASSERT(!ggml_is_quantized(src0->type));
  277
  278    GGML_TENSOR_UNARY_OP_LOCALS
  279
  280    const int ith = params->ith; // thread index
  281    const int nth = params->nth; // number of threads
  282
  283    // parallelize by rows
  284    const int nr = ne01;
  285    // number of rows per thread
  286    const int dr = (nr + nth - 1) / nth;
  287    // row range for this thread
  288    const int ir0 = dr * ith;
  289    const int ir1 = MIN(ir0 + dr, nr);
  290
  291    if (ggml_is_contiguous(dst) &&
  292            nb00 == sizeof(src_t) &&
  293            ggml_get_type_traits_cpu(dst->type)->from_float) {
  294        // casting non-quantized types --> intermediate f32 --> quantized
  295        ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
  296        float * src0_f32 = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
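        // each thread converts into its own slice of wdata (ne00 floats plus a cache line
        // of padding to avoid false sharing) before quantizing the row into dst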
  297
  298        size_t id = 0;
  299        size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
  300        char * dst_ptr = (char *) dst->data;
  301
  302        for (int i03 = 0; i03 < ne03; i03++) {
  303            for (int i02 = 0; i02 < ne02; i02++) {
  304                id += rs * ir0;
  305                for (int i01 = ir0; i01 < ir1; i01++) {
  306                    const src_t * src0_ptr = (src_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
  307
  308                    for (int i00 = 0; i00 < ne00; i00++) {
  309                        src0_f32[i00] = type_conversion_table<src_t>::to_f32(src0_ptr[i00]);
  310                    }
  311
  312                    quantize_row_q(src0_f32, dst_ptr + id, ne00);
  313                    id += rs;
  314                }
  315                id += rs * (ne01 - ir1);
  316            }
  317        }
  318    } else {
  319        // printf("%s %s\n", ggml_type_name(src0->type), ggml_type_name(dst->type));
  320        GGML_ABORT("not implemented");
  321    }
  322}
  323
  324// A simplified version of ggml_compute_forward_dup that doesn't do float upcasting, and just plain old memcpy.
  325static void ggml_compute_forward_dup_bytes(
  326        const ggml_compute_params * params,
  327        ggml_tensor * dst) {
  328    const ggml_tensor * src0 = dst->src[0];
  329
  330    GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
  331    GGML_ASSERT(src0->type == dst->type);
  332
  333    GGML_TENSOR_UNARY_OP_LOCALS;
  334
  335    if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
  336        ggml_compute_forward_dup_same_cont(params, dst);
  337        return;
  338    }
  339
  340    const size_t type_size = ggml_type_size(src0->type);
  341
  342    const int ith = params->ith; // thread index
  343    const int nth = params->nth; // number of threads
  344
  345    // parallelize by rows
  346    const int nr = ne01;
  347    // number of rows per thread
  348    const int dr = (nr + nth - 1) / nth;
  349    // row range for this thread
  350    const int ir0 = dr * ith;
  351    const int ir1 = MIN(ir0 + dr, nr);
  352
  353    if (src0->type == dst->type &&
  354        ggml_are_same_shape(src0, dst) &&
  355        nb00 == type_size && nb0 == type_size) {
  356        // copy by rows
  357        const size_t rs = ggml_row_size(src0->type, ne00);
  358        for (int64_t i03 = 0; i03 < ne03; i03++) {
  359            for (int64_t i02 = 0; i02 < ne02; i02++) {
  360                for (int64_t i01 = ir0; i01 < ir1; i01++) {
  361                    memcpy(
  362                        ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
  363                        ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03),
  364                        rs);
  365                }
  366            }
  367        }
  368        return;
  369    }
  370
  371    if (ggml_is_contiguous(dst)) {
  372        size_t id = 0;
  373        char * dst_ptr = (char *) dst->data;
  374        const size_t rs = ne00 * type_size;
  375
  376        if (nb00 == type_size) {
            // src0 is contiguous in the first dimension, copy by rows
  378            for (int64_t i03 = 0; i03 < ne03; i03++) {
  379                for (int64_t i02 = 0; i02 < ne02; i02++) {
  380                    id += rs * ir0;
  381                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
  382                        const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
  383                        memcpy(dst_ptr + id, src0_ptr, rs);
  384                        id += rs;
  385                    }
  386                    id += rs * (ne01 - ir1);
  387                }
  388            }
  389        } else {
  390            //printf("%s: this is not optimal - fix me\n", __func__);
  391
  392            for (int64_t i03 = 0; i03 < ne03; i03++) {
  393                for (int64_t i02 = 0; i02 < ne02; i02++) {
  394                    id += rs * ir0;
  395                    for (int64_t i01 = ir0; i01 < ir1; i01++) {
  396                        for (int64_t i00 = 0; i00 < ne00; i00++) {
  397                            const char * src0_ptr = (char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03;
  398                            memcpy(dst_ptr + id, src0_ptr, type_size);
  399
  400                            id += type_size;
  401                        }
  402                    }
  403                    id += rs * (ne01 - ir1);
  404                }
  405            }
  406        }
  407
  408        return;
  409    }
  410
  411    // dst counters
  412    int64_t k10 = 0;
  413    int64_t i11 = 0;
  414    int64_t i12 = 0;
  415    int64_t i13 = 0;
  416
  417    // number of blocks in a row
  418    const int64_t nk00 = ne00 / ggml_blck_size(src0->type);
  419    const int64_t nk0  = ne0  / ggml_blck_size(dst->type);
  420
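    // walk the data block by block: for quantized types each memcpy below moves one block
    // of type_size bytes, for non-quantized types the block size is 1 and this degenerates
    // into an element-wise copy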
  421    for (int64_t i03 = 0; i03 < ne03; i03++) {
  422        for (int64_t i02 = 0; i02 < ne02; i02++) {
  423            k10 += nk00 * ir0;
  424            while (k10 >= nk0) {
  425                k10 -= nk0;
  426                if (++i11 == ne1) {
  427                    i11 = 0;
  428                    if (++i12 == ne2) {
  429                        i12 = 0;
  430                        if (++i13 == ne3) {
  431                            i13 = 0;
  432                        }
  433                    }
  434                }
  435            }
  436            for (int64_t i01 = ir0; i01 < ir1; i01++) {
  437                for (int64_t k00 = 0; k00 < nk00; k00++) {
  438                    const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
  439                          char * dst_ptr  = ((char *)  dst->data + k10*nb0  + i11*nb1  + i12*nb2  + i13*nb3);
  440
  441                    memcpy(dst_ptr, src0_ptr, type_size);
  442
  443                    if (++k10 == nk0) {
  444                        k10 = 0;
  445                        if (++i11 == ne1) {
  446                            i11 = 0;
  447                            if (++i12 == ne2) {
  448                                i12 = 0;
  449                                if (++i13 == ne3) {
  450                                    i13 = 0;
  451                                }
  452                            }
  453                        }
  454                    }
  455                }
  456            }
  457            k10 += nk00 * (ne01 - ir1);
  458            while (k10 >= nk0) {
  459                k10 -= nk0;
  460                if (++i11 == ne1) {
  461                    i11 = 0;
  462                    if (++i12 == ne2) {
  463                        i12 = 0;
  464                        if (++i13 == ne3) {
  465                            i13 = 0;
  466                        }
  467                    }
  468                }
  469            }
  470        }
  471    }
  472}
  473
  474static void ggml_compute_forward_dup_from_q(
  475        const ggml_compute_params * params,
  476              ggml_tensor * dst) {
  477
  478    const ggml_tensor * src0 = dst->src[0];
  479    const ggml_tensor * src1 = dst->src[1];
  480
  481    GGML_TENSOR_BINARY_OP_LOCALS
  482
  483    const ggml_type type = src0->type;
  484    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  485
  486    size_t qk = ggml_blck_size(type);
  487    const int64_t nr = ggml_nelements(src1) / qk;
  488
  489    // destination must be contiguous in the first dimension
  490    GGML_ASSERT(nb10 == ggml_type_size(dst->type));
    // the dst first dimension must be a whole number of blocks, or dst must be fully contiguous
  492    GGML_ASSERT((ne10 % qk) == 0 || ggml_is_contiguous(dst));
  493
  494    const int ith = params->ith;
  495    const int nth = params->nth;
  496
  497    const int dr = (nr + nth - 1)/nth;
  498
  499    // row range for this thread
  500    const int ir0 = dr*ith;
  501    const int ir1 = MIN(ir0 + dr, nr);
  502
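    // each iteration handles one quantized block of qk values: the flat element index i is
    // decomposed into src0 coordinates (to locate the block) and into dst coordinates (to
    // locate where the qk dequantized floats are written)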
  503    for (int64_t ir = ir0; ir < ir1; ++ir) {
  504
        const int64_t i = ir * qk;
  506
  507        const int64_t i03 = i/(ne00 * ne01 * ne02);
  508        const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
  509        const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
  510        const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
  511        const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
  512
  513        const int64_t i13 = i/(ne10 * ne11 * ne12);
  514        const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
  515        const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
  516        const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
  517        const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
  518
  519        dequantize_row_q(
  520                (const void *) ((char *) src0->data + x_offset),
  521                     (float *) ((char *)  dst->data + dst_offset), qk);
  522    }
  523}
  524
  525void ggml_compute_forward_dup(
  526        const ggml_compute_params * params,
  527        ggml_tensor * dst) {
  528
  529    const ggml_tensor * src0 = dst->src[0];
  530
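    // same-type copies never need element conversion, so they go through the byte-level
    // path (this also covers quantized -> same quantized type)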
  531    if (src0->type == dst->type) {
  532        ggml_compute_forward_dup_bytes(params, dst);
  533        return;
  534    }
  535
  536    switch (src0->type) {
  537        case GGML_TYPE_F16:
  538            {
  539                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_fp16_t>(params, dst);
  540                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_fp16_t, ggml_bf16_t>(params, dst);
  541                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_fp16_t, float      >(params, dst);
  542                else ggml_compute_forward_dup_to_q<ggml_fp16_t>(params, dst);
  543            } break;
  544        case GGML_TYPE_BF16:
  545            {
  546                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_fp16_t>(params, dst);
  547                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<ggml_bf16_t, ggml_bf16_t>(params, dst);
  548                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<ggml_bf16_t, float      >(params, dst);
  549                else ggml_compute_forward_dup_to_q<ggml_bf16_t>(params, dst);
  550            } break;
  551        case GGML_TYPE_F32:
  552            {
  553                /**/ if (dst->type == GGML_TYPE_F16)  ggml_compute_forward_dup_flt<float, ggml_fp16_t>(params, dst);
  554                else if (dst->type == GGML_TYPE_BF16) ggml_compute_forward_dup_flt<float, ggml_bf16_t>(params, dst);
  555                else if (dst->type == GGML_TYPE_F32)  ggml_compute_forward_dup_flt<float, float      >(params, dst);
  556                else if (dst->type == GGML_TYPE_I32)  ggml_compute_forward_dup_flt<float, int32_t    >(params, dst);
  557                else ggml_compute_forward_dup_to_q<float>(params, dst);
  558            } break;
  559        case GGML_TYPE_I32:
  560            {
  561                if (dst->type == GGML_TYPE_F32) ggml_compute_forward_dup_flt<int32_t, float>(params, dst);
  562                else GGML_ABORT("not implemented");
  563            } break;
  564        default:
  565            {
  566                if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
  567                    ggml_compute_forward_dup_from_q(params, dst);
  568                    break;
  569                }
  570                GGML_ABORT("fatal error");
  571            }
  572    }
  573}
  574
  575// ggml_compute_forward_add
  576
  577static void ggml_compute_forward_add_q_f32(
  578        const ggml_compute_params * params,
  579        ggml_tensor * dst) {
  580
  581    const ggml_tensor * src0 = dst->src[0];
  582    const ggml_tensor * src1 = dst->src[1];
  583
  584    GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
  585
  586    const int nr  = ggml_nrows(src0);
  587
  588    GGML_TENSOR_BINARY_OP_LOCALS
  589
  590    const int ith = params->ith;
  591    const int nth = params->nth;
  592
  593    const ggml_type type = src0->type;
  594    const ggml_type dtype = dst->type;
  595    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  596    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dtype)->from_float;
  597
  598    // we don't support permuted src0 or src1
  599    GGML_ASSERT(nb00 == ggml_type_size(type));
  600    GGML_ASSERT(nb10 == sizeof(float));
  601
  602    // dst cannot be transposed or permuted
  603    GGML_ASSERT(nb0 <= nb1);
  604    GGML_ASSERT(nb1 <= nb2);
  605    GGML_ASSERT(nb2 <= nb3);
  606
  607    GGML_ASSERT(ggml_is_quantized(src0->type));
  608    GGML_ASSERT(src1->type == GGML_TYPE_F32);
  609
  610    // rows per thread
  611    const int dr = (nr + nth - 1)/nth;
  612
  613    // row range for this thread
  614    const int ir0 = dr*ith;
  615    const int ir1 = MIN(ir0 + dr, nr);
  616
  617    float * wdata = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
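    // per-thread scratch row: src0 is dequantized into wdata, src1 is accumulated on top,
    // and the result is re-quantized (or copied) into dst; the CACHE_LINE_SIZE_F32 padding
    // keeps the threads' slices on separate cache lines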
  618
  619    for (int ir = ir0; ir < ir1; ++ir) {
  620        // src0 indices
  621        const int i03 = ir/(ne02*ne01);
  622        const int i02 = (ir - i03*ne02*ne01)/ne01;
  623        const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
  624
  625        // src1 and dst are same shape as src0 => same indices
  626        const int i13 = i03;
  627        const int i12 = i02;
  628        const int i11 = i01;
  629
  630        const int i3 = i03;
  631        const int i2 = i02;
  632        const int i1 = i01;
  633
  634        void  * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
  635        float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
  636        void  * dst_row  = (void *) ((char *)  dst->data + ( i1*nb1  +  i2*nb2  +  i3*nb3));
  637
  638        assert(ne00 % 32 == 0);
  639
        // dequantize row from src0 into the temp buffer
  641        dequantize_row_q(src0_row, wdata, ne00);
  642        // add src1
  643        ggml_vec_acc_f32(ne00, wdata, src1_row);
  644        // quantize row to dst
  645        if (quantize_row_q != NULL) {
  646            quantize_row_q(wdata, dst_row, ne00);
  647        } else {
  648            memcpy(dst_row, wdata, ne0*nb0);
  649        }
  650    }
  651}
  652
  653void ggml_compute_forward_add(
  654        const ggml_compute_params * params,
  655        ggml_tensor * dst) {
  656
  657    const ggml_tensor * src0 = dst->src[0];
  658
  659    switch (src0->type) {
  660        case GGML_TYPE_F32:
  661        case GGML_TYPE_F16:
  662        case GGML_TYPE_BF16:
  663            {
  664                ggml_compute_forward_add_non_quantized(params, dst);
  665            } break;
  666        case GGML_TYPE_Q4_0:
  667        case GGML_TYPE_Q4_1:
  668        case GGML_TYPE_Q5_0:
  669        case GGML_TYPE_Q5_1:
  670        case GGML_TYPE_Q8_0:
  671        case GGML_TYPE_MXFP4:
  672        case GGML_TYPE_Q2_K:
  673        case GGML_TYPE_Q3_K:
  674        case GGML_TYPE_Q4_K:
  675        case GGML_TYPE_Q5_K:
  676        case GGML_TYPE_Q6_K:
  677        case GGML_TYPE_TQ1_0:
  678        case GGML_TYPE_TQ2_0:
  679        case GGML_TYPE_IQ2_XXS:
  680        case GGML_TYPE_IQ2_XS:
  681        case GGML_TYPE_IQ3_XXS:
  682        case GGML_TYPE_IQ1_S:
  683        case GGML_TYPE_IQ1_M:
  684        case GGML_TYPE_IQ4_NL:
  685        case GGML_TYPE_IQ4_XS:
  686        case GGML_TYPE_IQ3_S:
  687        case GGML_TYPE_IQ2_S:
  688            {
  689                ggml_compute_forward_add_q_f32(params, dst);
  690            } break;
  691        default:
  692            {
  693                GGML_ABORT("fatal error");
  694            }
  695    }
  696}
  697
  698// ggml_compute_forward_add_id
  699
  700static void ggml_compute_forward_add_id_f32(
  701        const ggml_compute_params * params,
  702        ggml_tensor * dst) {
  703
  704    const ggml_tensor * src0 = dst->src[0];
  705    const ggml_tensor * src1 = dst->src[1];
  706    const ggml_tensor * src2 = dst->src[2];
  707
  708    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
  709    GGML_ASSERT(src0->type == GGML_TYPE_F32);
  710    GGML_ASSERT(src1->type == GGML_TYPE_F32);
  711    GGML_ASSERT(src2->type == GGML_TYPE_I32);
  712
  713    GGML_ASSERT(src0->nb[0] == sizeof(float));
  714    GGML_ASSERT(src1->nb[0] == sizeof(float));
  715
  716    const int ith = params->ith;
  717    const int nth = params->nth;
  718
  719    const int nr  = ggml_nrows(src0);
  720
  721    GGML_TENSOR_TERNARY_OP_LOCALS
  722
  723    GGML_ASSERT( nb0 == sizeof(float));
  724    GGML_ASSERT(nb10 == sizeof(float));
  725
  726    // rows per thread
  727    const int dr = (nr + nth - 1)/nth;
  728
  729    // row range for this thread
  730    const int ir0 = dr*ith;
  731    const int ir1 = MIN(ir0 + dr, nr);
  732
  733    for (int ir = ir0; ir < ir1; ++ir) {
  734        // src0 indices
  735        const int i3 = ir/(ne2*ne1);
  736        const int i2 = (ir - i3*ne2*ne1)/ne1;
  737        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  738
  739        // src1 indices
  740        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
  741
  742        GGML_ASSERT(i11 >= 0 && i11 < ne11);
  743
  744        ggml_vec_add_f32(ne0,
  745                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
  746                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  747                (float *) ((char *) src1->data + i11*nb11));
  748    }
  749}
  750
  751void ggml_compute_forward_add_id(
  752        const ggml_compute_params * params,
  753        ggml_tensor * dst) {
  754
  755    const ggml_tensor * src0 = dst->src[0];
  756
  757    switch (src0->type) {
  758        case GGML_TYPE_F32:
  759            {
  760                ggml_compute_forward_add_id_f32(params, dst);
  761            } break;
  762        default:
  763            {
  764                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
  765            }
  766    }
  767}
  768
  769// ggml_compute_forward_add1
  770
  771static void ggml_compute_forward_add1_f32(
  772        const ggml_compute_params * params,
  773        ggml_tensor * dst) {
  774
  775    const ggml_tensor * src0 = dst->src[0];
  776    const ggml_tensor * src1 = dst->src[1];
  777
  778    GGML_ASSERT(ggml_are_same_shape(src0, dst));
  779    GGML_ASSERT(ggml_is_scalar(src1));
  780
  781    const int ith = params->ith;
  782    const int nth = params->nth;
  783
  784    const int nr  = ggml_nrows(src0);
  785
  786    GGML_TENSOR_UNARY_OP_LOCALS
  787
  788    GGML_ASSERT( nb0 == sizeof(float));
  789    GGML_ASSERT(nb00 == sizeof(float));
  790
  791    // rows per thread
  792    const int dr = (nr + nth - 1)/nth;
  793
  794    // row range for this thread
  795    const int ir0 = dr*ith;
  796    const int ir1 = MIN(ir0 + dr, nr);
  797
  798    for (int ir = ir0; ir < ir1; ++ir) {
  799        // src0 and dst are same shape => same indices
  800        const int i3 = ir/(ne2*ne1);
  801        const int i2 = (ir - i3*ne2*ne1)/ne1;
  802        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  803
  804#ifdef GGML_USE_ACCELERATE
  805        GGML_UNUSED(ggml_vec_add1_f32);
  806
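        // vDSP_vadd with a stride of 0 on the second operand re-reads the same scalar for
        // every element, which broadcasts the src1 value across the row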
  807        vDSP_vadd(
  808                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
  809                (float *) ((char *) src1->data), 0,
  810                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ), 1,
  811                ne0);
  812#else
  813        ggml_vec_add1_f32(ne0,
  814                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
  815                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
  816               *(float *) src1->data);
  817#endif
  818    }
  819}
  820
  821static void ggml_compute_forward_add1_f16_f32(
  822        const ggml_compute_params * params,
  823        ggml_tensor * dst) {
  824
  825    const ggml_tensor * src0 = dst->src[0];
  826    const ggml_tensor * src1 = dst->src[1];
  827
  828    GGML_ASSERT(ggml_are_same_shape(src0, dst));
  829    GGML_ASSERT(ggml_is_scalar(src1));
  830
  831    // scalar to add
  832    const float v = *(float *) src1->data;
  833
  834    const int ith = params->ith;
  835    const int nth = params->nth;
  836
  837    const int nr  = ggml_nrows(src0);
  838
  839    GGML_TENSOR_UNARY_OP_LOCALS
  840
  841    GGML_ASSERT(src0->type == GGML_TYPE_F16);
  842    GGML_ASSERT(src1->type == GGML_TYPE_F32);
  843    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
  844
  845    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  846    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  847
  848    // rows per thread
  849    const int dr = (nr + nth - 1)/nth;
  850
  851    // row range for this thread
  852    const int ir0 = dr*ith;
  853    const int ir1 = MIN(ir0 + dr, nr);
  854
  855    for (int ir = ir0; ir < ir1; ++ir) {
  856        // src0 and dst are same shape => same indices
  857        const int i3 = ir/(ne2*ne1);
  858        const int i2 = (ir - i3*ne2*ne1)/ne1;
  859        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  860
  861        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
  862        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  863        for (int i = 0; i < ne0; i++) {
  864            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  865        }
  866    }
  867}
  868
  869static void ggml_compute_forward_add1_f16_f16(
  870        const ggml_compute_params * params,
  871        ggml_tensor * dst) {
  872
  873    const ggml_tensor * src0 = dst->src[0];
  874    const ggml_tensor * src1 = dst->src[1];
  875
  876    GGML_ASSERT(ggml_are_same_shape(src0, dst));
  877    GGML_ASSERT(ggml_is_scalar(src1));
  878
  879    // scalar to add
  880    const float v = GGML_CPU_FP16_TO_FP32(*(ggml_fp16_t *) src1->data);
  881
  882    const int ith = params->ith;
  883    const int nth = params->nth;
  884
  885    const int nr  = ggml_nrows(src0);
  886
  887    GGML_TENSOR_UNARY_OP_LOCALS
  888
  889    GGML_ASSERT(src0->type == GGML_TYPE_F16);
  890    GGML_ASSERT(src1->type == GGML_TYPE_F16);
  891    GGML_ASSERT(dst->type  == GGML_TYPE_F16);
  892
  893    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
  894    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
  895
  896    // rows per thread
  897    const int dr = (nr + nth - 1)/nth;
  898
  899    // row range for this thread
  900    const int ir0 = dr*ith;
  901    const int ir1 = MIN(ir0 + dr, nr);
  902
  903    for (int ir = ir0; ir < ir1; ++ir) {
  904        // src0 and dst are same shape => same indices
  905        const int i3 = ir/(ne2*ne1);
  906        const int i2 = (ir - i3*ne2*ne1)/ne1;
  907        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  908
  909        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
  910        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
  911        for (int i = 0; i < ne0; i++) {
  912            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(src0_ptr[i]) + v);
  913        }
  914    }
  915}
  916
  917static void ggml_compute_forward_add1_q_f32(
  918        const ggml_compute_params * params,
  919        ggml_tensor * dst) {
  920
  921    const ggml_tensor * src0 = dst->src[0];
  922    const ggml_tensor * src1 = dst->src[1];
  923
  924    GGML_ASSERT(ggml_are_same_shape(src0, dst));
  925    GGML_ASSERT(ggml_is_scalar(src1));
  926
  927    // scalar to add
  928    const float v = *(float *) src1->data;
  929
  930    const int ith = params->ith;
  931    const int nth = params->nth;
  932
  933    const int nr  = ggml_nrows(src0);
  934
  935    GGML_TENSOR_UNARY_OP_LOCALS
  936
  937    const ggml_type type = src0->type;
  938    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
  939    ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(type)->from_float;
  940
  941    // we don't support permuted src0
  942    GGML_ASSERT(nb00 == ggml_type_size(type));
  943
  944    // dst cannot be transposed or permuted
  945    GGML_ASSERT(nb0 <= nb1);
  946    GGML_ASSERT(nb1 <= nb2);
  947    GGML_ASSERT(nb2 <= nb3);
  948
  949    GGML_ASSERT(ggml_is_quantized(src0->type));
  950    GGML_ASSERT(dst->type == src0->type);
  951    GGML_ASSERT(src1->type == GGML_TYPE_F32);
  952
  953    // rows per thread
  954    const int dr = (nr + nth - 1)/nth;
  955
  956    // row range for this thread
  957    const int ir0 = dr*ith;
  958    const int ir1 = MIN(ir0 + dr, nr);
  959
  960    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
  961
  962    for (int ir = ir0; ir < ir1; ++ir) {
  963        // src0 and dst are same shape => same indices
  964        const int i3 = ir/(ne2*ne1);
  965        const int i2 = (ir - i3*ne2*ne1)/ne1;
  966        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
  967
  968        void  * src0_row = (void *) ((char *) src0->data + (i1*nb01 + i2*nb02 + i3*nb03));
        void  * dst_row  = (void *) ((char *)  dst->data + (i1*nb1  + i2*nb2  + i3*nb3 ));
  970
  971        assert(ne0 % 32 == 0);
  972
        // dequantize row from src0 into the temp buffer
  974        dequantize_row_q(src0_row, wdata, ne0);
  975        // add src1
  976        ggml_vec_acc1_f32(ne0, wdata, v);
  977        // quantize row to dst
  978        quantize_row_q(wdata, dst_row, ne0);
  979    }
  980}
  981
  982static void ggml_compute_forward_add1_bf16_f32(
  983        const ggml_compute_params * params,
  984        ggml_tensor * dst) {
  985
  986    const ggml_tensor * src0 = dst->src[0];
  987    const ggml_tensor * src1 = dst->src[1];
  988
  989    GGML_ASSERT(ggml_are_same_shape(src0, dst));
  990    GGML_ASSERT(ggml_is_scalar(src1));
  991
  992    // scalar to add
  993    const float v = *(float *) src1->data;
  994
  995    const int ith = params->ith;
  996    const int nth = params->nth;
  997
  998    const int nr  = ggml_nrows(src0);
  999
 1000    GGML_TENSOR_UNARY_OP_LOCALS
 1001
 1002    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
 1003    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 1004    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
 1005
 1006    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
 1007    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
 1008
 1009    // rows per thread
 1010    const int dr = (nr + nth - 1)/nth;
 1011
 1012    // row range for this thread
 1013    const int ir0 = dr*ith;
 1014    const int ir1 = MIN(ir0 + dr, nr);
 1015
 1016    for (int ir = ir0; ir < ir1; ++ir) {
 1017        // src0 and dst are same shape => same indices
 1018        const int i3 = ir/(ne2*ne1);
 1019        const int i2 = (ir - i3*ne2*ne1)/ne1;
 1020        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 1021
 1022        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
 1023        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
 1024        for (int i = 0; i < ne0; i++) {
 1025            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
 1026        }
 1027    }
 1028}
 1029
 1030static void ggml_compute_forward_add1_bf16_bf16(
 1031        const ggml_compute_params * params,
 1032        ggml_tensor * dst) {
 1033
 1034    const ggml_tensor * src0 = dst->src[0];
 1035    const ggml_tensor * src1 = dst->src[1];
 1036
 1037    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 1038    GGML_ASSERT(ggml_is_scalar(src1));
 1039
 1040    // scalar to add
 1041    const float v = GGML_BF16_TO_FP32(*(ggml_bf16_t *) src1->data);
 1042
 1043    const int ith = params->ith;
 1044    const int nth = params->nth;
 1045
 1046    const int nr  = ggml_nrows(src0);
 1047
 1048    GGML_TENSOR_UNARY_OP_LOCALS
 1049
 1050    GGML_ASSERT(src0->type == GGML_TYPE_BF16);
 1051    GGML_ASSERT(src1->type == GGML_TYPE_BF16);
 1052    GGML_ASSERT(dst->type  == GGML_TYPE_BF16);
 1053
 1054    GGML_ASSERT( nb0 == sizeof(ggml_bf16_t));
 1055    GGML_ASSERT(nb00 == sizeof(ggml_bf16_t));
 1056
 1057    // rows per thread
 1058    const int dr = (nr + nth - 1)/nth;
 1059
 1060    // row range for this thread
 1061    const int ir0 = dr*ith;
 1062    const int ir1 = MIN(ir0 + dr, nr);
 1063
 1064    for (int ir = ir0; ir < ir1; ++ir) {
 1065        // src0 and dst are same shape => same indices
 1066        const int i3 = ir/(ne2*ne1);
 1067        const int i2 = (ir - i3*ne2*ne1)/ne1;
 1068        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
 1069
 1070        ggml_bf16_t * dst_ptr  = (ggml_bf16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 );
 1071        ggml_bf16_t * src0_ptr = (ggml_bf16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
 1072        for (int i = 0; i < ne0; i++) {
 1073            dst_ptr[i] = GGML_FP32_TO_BF16(GGML_BF16_TO_FP32(src0_ptr[i]) + v);
 1074        }
 1075    }
 1076}
 1077
 1078void ggml_compute_forward_add1(
 1079        const ggml_compute_params * params,
 1080        ggml_tensor * dst) {
 1081
 1082    const ggml_tensor * src0 = dst->src[0];
 1083    const ggml_tensor * src1 = dst->src[1];
 1084
 1085    switch (src0->type) {
 1086        case GGML_TYPE_F32:
 1087            {
 1088                ggml_compute_forward_add1_f32(params, dst);
 1089            } break;
 1090        case GGML_TYPE_F16:
 1091            {
 1092                if (src1->type == GGML_TYPE_F16) {
 1093                    ggml_compute_forward_add1_f16_f16(params, dst);
 1094                }
 1095                else if (src1->type == GGML_TYPE_F32) {
 1096                    ggml_compute_forward_add1_f16_f32(params, dst);
 1097                }
 1098                else {
 1099                    GGML_ABORT("fatal error");
 1100                }
 1101            } break;
 1102        case GGML_TYPE_BF16:
 1103            {
 1104                if (src1->type == GGML_TYPE_BF16) {
 1105                    ggml_compute_forward_add1_bf16_bf16(params, dst);
 1106                }
 1107                else if (src1->type == GGML_TYPE_F32) {
 1108                    ggml_compute_forward_add1_bf16_f32(params, dst);
 1109                }
 1110                else {
 1111                    GGML_ABORT("fatal error");
 1112                }
 1113            } break;
 1114        case GGML_TYPE_Q4_0:
 1115        case GGML_TYPE_Q4_1:
 1116        case GGML_TYPE_Q5_0:
 1117        case GGML_TYPE_Q5_1:
 1118        case GGML_TYPE_Q8_0:
 1119        case GGML_TYPE_Q8_1:
 1120        case GGML_TYPE_MXFP4:
 1121        case GGML_TYPE_Q2_K:
 1122        case GGML_TYPE_Q3_K:
 1123        case GGML_TYPE_Q4_K:
 1124        case GGML_TYPE_Q5_K:
 1125        case GGML_TYPE_Q6_K:
 1126        case GGML_TYPE_TQ1_0:
 1127        case GGML_TYPE_TQ2_0:
 1128        case GGML_TYPE_IQ2_XXS:
 1129        case GGML_TYPE_IQ2_XS:
 1130        case GGML_TYPE_IQ3_XXS:
 1131        case GGML_TYPE_IQ1_S:
 1132        case GGML_TYPE_IQ1_M:
 1133        case GGML_TYPE_IQ4_NL:
 1134        case GGML_TYPE_IQ4_XS:
 1135        case GGML_TYPE_IQ3_S:
 1136        case GGML_TYPE_IQ2_S:
 1137            {
 1138                ggml_compute_forward_add1_q_f32(params, dst);
 1139            } break;
 1140        default:
 1141            {
 1142                GGML_ABORT("fatal error");
 1143            }
 1144    }
 1145}
 1146
 1147// ggml_compute_forward_acc
 1148
 1149static void ggml_compute_forward_acc_f32(
 1150        const ggml_compute_params * params,
 1151        ggml_tensor * dst) {
 1152
 1153    const ggml_tensor * src0 = dst->src[0];
 1154    const ggml_tensor * src1 = dst->src[1];
 1155
 1156    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 1157    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 1158
    // view src0 and dst with these strides and data offset in bytes during acc
 1160    // nb0 is implicitly element_size because src0 and dst are contiguous
 1161    size_t nb1     = ((int32_t *) dst->op_params)[0];
 1162    size_t nb2     = ((int32_t *) dst->op_params)[1];
 1163    size_t nb3     = ((int32_t *) dst->op_params)[2];
 1164    size_t offset  = ((int32_t *) dst->op_params)[3];
 1165    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
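    // op_params layout: [0..2] = nb1/nb2/nb3 strides of the view, [3] = byte offset of the
    // view inside dst, [4] = inplace flag; src1 is accumulated into this view of src0/dst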
 1166
 1167    if (!inplace) {
 1168        if (params->ith == 0) {
            // the memcpy needs to be synchronized across threads to avoid race conditions:
            // only thread 0 performs the copy, and all threads wait on the barrier below
 1171            memcpy(
 1172                ((char *)  dst->data),
 1173                ((char *) src0->data),
 1174                ggml_nbytes(dst));
 1175        }
 1176        ggml_barrier(params->threadpool);
 1177    }
 1178
 1179    const int ith = params->ith;
 1180    const int nth = params->nth;
 1181
 1182    const int nr = ggml_nrows(src1);
 1183    const int nc = src1->ne[0];
 1184
 1185    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
 1186    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 1187
 1188    // src0 and dst as viewed during acc
 1189    const size_t nb0 = ggml_element_size(src0);
 1190
 1191    const size_t nb00 = nb0;
 1192    const size_t nb01 = nb1;
 1193    const size_t nb02 = nb2;
 1194    const size_t nb03 = nb3;
 1195
 1196    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb0  + (ne11 == 0 ? 0 : ne11-1)*nb1  + (ne12 == 0 ? 0 : ne12-1)*nb2  + (ne13 == 0 ? 0 : ne13-1)*nb3  < ggml_nbytes(dst));
 1197    GGML_ASSERT(offset + (ne10 == 0 ? 0 : ne10-1)*nb00 + (ne11 == 0 ? 0 : ne11-1)*nb01 + (ne12 == 0 ? 0 : ne12-1)*nb02 + (ne13 == 0 ? 0 : ne13-1)*nb03 < ggml_nbytes(src0));
 1198
 1199    GGML_ASSERT(nb10 == sizeof(float));
 1200
 1201    // rows per thread
 1202    const int dr = (nr + nth - 1)/nth;
 1203
 1204    // row range for this thread
 1205    const int ir0 = dr*ith;
 1206    const int ir1 = MIN(ir0 + dr, nr);
 1207
 1208    for (int ir = ir0; ir < ir1; ++ir) {
 1209        // src0 and dst are viewed with shape of src1 and offset
 1210        // => same indices
 1211        const int i3 = ir/(ne12*ne11);
 1212        const int i2 = (ir - i3*ne12*ne11)/ne11;
 1213        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
 1214
 1215#ifdef GGML_USE_ACCELERATE
 1216        vDSP_vadd(
 1217                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset), 1,
 1218                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
 1219                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1  + offset), 1, nc);
 1220#else
 1221        ggml_vec_add_f32(nc,
 1222                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
 1223                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + offset),
 1224                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
 1225#endif
 1226    }
 1227}
 1228
 1229void ggml_compute_forward_acc(
 1230        const ggml_compute_params * params,
 1231        ggml_tensor * dst) {
 1232
 1233    const ggml_tensor * src0 = dst->src[0];
 1234
 1235    switch (src0->type) {
 1236        case GGML_TYPE_F32:
 1237            {
 1238                ggml_compute_forward_acc_f32(params, dst);
 1239            } break;
 1240        case GGML_TYPE_F16:
 1241        case GGML_TYPE_BF16:
 1242        case GGML_TYPE_Q4_0:
 1243        case GGML_TYPE_Q4_1:
 1244        case GGML_TYPE_Q5_0:
 1245        case GGML_TYPE_Q5_1:
 1246        case GGML_TYPE_Q8_0:
 1247        case GGML_TYPE_Q8_1:
 1248        case GGML_TYPE_MXFP4:
 1249        case GGML_TYPE_Q2_K:
 1250        case GGML_TYPE_Q3_K:
 1251        case GGML_TYPE_Q4_K:
 1252        case GGML_TYPE_Q5_K:
 1253        case GGML_TYPE_Q6_K:
 1254        case GGML_TYPE_TQ1_0:
 1255        case GGML_TYPE_TQ2_0:
 1256        case GGML_TYPE_IQ2_XXS:
 1257        case GGML_TYPE_IQ2_XS:
 1258        case GGML_TYPE_IQ3_XXS:
 1259        case GGML_TYPE_IQ1_S:
 1260        case GGML_TYPE_IQ1_M:
 1261        case GGML_TYPE_IQ4_NL:
 1262        case GGML_TYPE_IQ4_XS:
 1263        case GGML_TYPE_IQ3_S:
 1264        case GGML_TYPE_IQ2_S:
 1265        default:
 1266            {
 1267                GGML_ABORT("fatal error");
 1268            }
 1269    }
 1270}
 1271
 1272// ggml_compute_forward_sum
 1273
 1274static void ggml_compute_forward_sum_f32(
 1275        const ggml_compute_params * params,
 1276        ggml_tensor * dst) {
 1277
 1278    const ggml_tensor * src0 = dst->src[0];
 1279
 1280    if (params->ith != 0) {
 1281        return;
 1282    }
 1283
 1284    assert(ggml_is_scalar(dst));
 1285    assert(src0->nb[0] == sizeof(float));
 1286
 1287    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 1288    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 1289
 1290    ggml_float sum     = 0;
 1291    ggml_float row_sum = 0;
 1292
 1293    for (int64_t i03 = 0; i03 < ne03; i03++) {
 1294        for (int64_t i02 = 0; i02 < ne02; i02++) {
 1295            for (int64_t i01 = 0; i01 < ne01; i01++) {
 1296                ggml_vec_sum_f32_ggf(ne00,
 1297                        &row_sum,
 1298                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
 1299                sum += row_sum;
 1300            }
 1301        }
 1302    }
 1303    ((float *) dst->data)[0] = sum;
 1304}
 1305
 1306static void ggml_compute_forward_sum_f16(
 1307    const ggml_compute_params * params,
 1308          ggml_tensor * dst) {
 1309
 1310    const ggml_tensor * src0 = dst->src[0];
 1311
 1312    if (params->ith != 0) {
 1313        return;
 1314    }
 1315
 1316    assert(ggml_is_scalar(dst));
 1317
 1318    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 1319
 1320    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 1321    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 1322
 1323    float sum = 0;
 1324    float row_sum = 0;
 1325
 1326    for (int64_t i03 = 0; i03 < ne03; i03++) {
 1327        for (int64_t i02 = 0; i02 < ne02; i02++) {
 1328            for (int64_t i01 = 0; i01 < ne01; i01++) {
 1329                ggml_vec_sum_f16_ggf(ne00,
 1330                    &row_sum,
 1331                    (ggml_fp16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
 1332                sum += row_sum;
 1333            }
 1334        }
 1335    }
 1336    ((ggml_fp16_t *) dst->data)[0] = GGML_CPU_FP32_TO_FP16(sum);
 1337}
 1338
 1339static void ggml_compute_forward_sum_bf16(
 1340    const ggml_compute_params * params,
 1341          ggml_tensor * dst) {
 1342
 1343    const ggml_tensor * src0 = dst->src[0];
 1344
 1345    if (params->ith != 0) {
 1346        return;
 1347    }
 1348
 1349    assert(ggml_is_scalar(dst));
 1350
 1351    assert(src0->nb[0] == sizeof(ggml_bf16_t));
 1352
 1353    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 1354    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 1355
 1356    float sum = 0;
 1357    float row_sum = 0;
 1358
 1359    for (int64_t i03 = 0; i03 < ne03; i03++) {
 1360        for (int64_t i02 = 0; i02 < ne02; i02++) {
 1361            for (int64_t i01 = 0; i01 < ne01; i01++) {
 1362                ggml_vec_sum_bf16_ggf(ne00,
 1363                    &row_sum,
 1364                    (ggml_bf16_t *) ((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03));
 1365                sum += row_sum;
 1366            }
 1367        }
 1368    }
 1369    ((ggml_bf16_t *) dst->data)[0] = GGML_FP32_TO_BF16(sum);
 1370}
 1371
 1372void ggml_compute_forward_sum(
 1373        const ggml_compute_params * params,
 1374        ggml_tensor * dst) {
 1375
 1376    const ggml_tensor * src0 = dst->src[0];
 1377
 1378    switch (src0->type) {
 1379        case GGML_TYPE_F32:
 1380            {
 1381                ggml_compute_forward_sum_f32(params, dst);
 1382            } break;
 1383        case GGML_TYPE_F16:
 1384            {
 1385                ggml_compute_forward_sum_f16(params, dst);
 1386            } break;
 1387        case GGML_TYPE_BF16:
 1388            {
 1389                ggml_compute_forward_sum_bf16(params, dst);
 1390            } break;
 1391        default:
 1392            {
 1393                GGML_ABORT("fatal error");
 1394            }
 1395    }
 1396}
 1397
 1398// ggml_compute_forward_cumsum
 1399
 1400static void ggml_compute_forward_cumsum_f32(
 1401        const ggml_compute_params * params,
 1402        ggml_tensor * dst) {
 1403
 1404    const ggml_tensor * src0 = dst->src[0];
 1405
 1406    GGML_ASSERT(src0->nb[0] == sizeof(float));
 1407    GGML_ASSERT(dst->nb[0] == sizeof(float));
 1408
 1409    GGML_TENSOR_UNARY_OP_LOCALS
 1410
 1411    GGML_ASSERT(ne0 == ne00);
 1412    GGML_ASSERT(ne1 == ne01);
 1413    GGML_ASSERT(ne2 == ne02);
 1414    GGML_ASSERT(ne3 == ne03);
 1415
 1416    const auto [ir0, ir1] = get_thread_range(params, src0);
 1417
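    // rows are split across threads by get_thread_range; each row's running (prefix) sum
    // along dimension 0 is computed independently by ggml_vec_cumsum_f32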
 1418    for (int64_t ir = ir0; ir < ir1; ++ir) {
 1419        const int64_t i03 = ir/(ne02*ne01);
 1420        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
 1421        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 1422
 1423        float * src_row = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 1424        float * dst_row = (float *) ((char *) dst->data  + i01*nb1  + i02*nb2  + i03*nb3);
 1425
 1426        ggml_vec_cumsum_f32(ne00, dst_row, src_row);
 1427    }
 1428}
 1429
 1430void ggml_compute_forward_cumsum(
 1431        const ggml_compute_params * params,
 1432        ggml_tensor * dst) {
 1433
 1434    const ggml_tensor * src0 = dst->src[0];
 1435
 1436    switch (src0->type) {
 1437        case GGML_TYPE_F32:
 1438            {
 1439                ggml_compute_forward_cumsum_f32(params, dst);
 1440            } break;
 1441        default:
 1442            {
 1443                GGML_ABORT("fatal error");
 1444            }
 1445    }
 1446}
 1447
 1448// ggml_compute_forward_sum_rows
 1449
 1450static void ggml_compute_forward_sum_rows_f32(
 1451        const ggml_compute_params * params,
 1452        ggml_tensor * dst) {
 1453
 1454    const ggml_tensor * src0 = dst->src[0];
 1455
 1456    if (params->ith != 0) {
 1457        return;
 1458    }
 1459
 1460    GGML_ASSERT(src0->nb[0] == sizeof(float));
 1461    GGML_ASSERT(dst->nb[0] == sizeof(float));
 1462
 1463    GGML_TENSOR_UNARY_OP_LOCALS
 1464
 1465    GGML_ASSERT(ne0 == 1);
 1466    GGML_ASSERT(ne1 == ne01);
 1467    GGML_ASSERT(ne2 == ne02);
 1468    GGML_ASSERT(ne3 == ne03);
 1469
 1470    for (int64_t i3 = 0; i3 < ne03; i3++) {
 1471        for (int64_t i2 = 0; i2 < ne02; i2++) {
 1472            for (int64_t i1 = 0; i1 < ne01; i1++) {
 1473                float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
 1474                float * dst_row = (float *) ((char *) dst->data  + i1*nb1  + i2*nb2  + i3*nb3);
 1475                float row_sum = 0;
 1476                ggml_vec_sum_f32(ne00, &row_sum, src_row);
 1477                dst_row[0] = row_sum;
 1478            }
 1479        }
 1480    }
 1481}
 1482
 1483void ggml_compute_forward_sum_rows(
 1484        const ggml_compute_params * params,
 1485        ggml_tensor * dst) {
 1486
 1487    const ggml_tensor * src0 = dst->src[0];
 1488
 1489    switch (src0->type) {
 1490        case GGML_TYPE_F32:
 1491            {
 1492                ggml_compute_forward_sum_rows_f32(params, dst);
 1493            } break;
 1494        default:
 1495            {
 1496                GGML_ABORT("fatal error");
 1497            }
 1498    }
 1499}
 1500
 1501// ggml_compute_forward_mean
 1502
 1503static void ggml_compute_forward_mean_f32(
 1504        const ggml_compute_params * params,
 1505        ggml_tensor * dst) {
 1506
 1507    const ggml_tensor * src0 = dst->src[0];
 1508
 1509    if (params->ith != 0) {
 1510        return;
 1511    }
 1512
 1513    assert(src0->nb[0] == sizeof(float));
 1514
 1515    GGML_TENSOR_UNARY_OP_LOCALS
 1516
 1517    assert(ne0 == 1);
 1518    assert(ne1 == ne01);
 1519    assert(ne2 == ne02);
 1520    assert(ne3 == ne03);
 1521
 1522    GGML_UNUSED(ne0);
 1523    GGML_UNUSED(ne1);
 1524    GGML_UNUSED(ne2);
 1525    GGML_UNUSED(ne3);
 1526
 1527    for (int64_t i03 = 0; i03 < ne03; i03++) {
 1528        for (int64_t i02 = 0; i02 < ne02; i02++) {
 1529            for (int64_t i01 = 0; i01 < ne01; i01++) {
 1530                ggml_vec_sum_f32(ne00,
 1531                        (float *) ((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3),
 1532                        (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03));
 1533
 1534                *(float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3) /= (float) ne00;
 1535            }
 1536        }
 1537    }
 1538}
 1539
 1540void ggml_compute_forward_mean(
 1541        const ggml_compute_params * params,
 1542        ggml_tensor * dst) {
 1543
 1544    const ggml_tensor * src0 = dst->src[0];
 1545
 1546    switch (src0->type) {
 1547        case GGML_TYPE_F32:
 1548            {
 1549                ggml_compute_forward_mean_f32(params, dst);
 1550            } break;
 1551        default:
 1552            {
 1553                GGML_ABORT("fatal error");
 1554            }
 1555    }
 1556}
 1557
 1558// ggml_compute_forward_argmax
 1559
 1560static void ggml_compute_forward_argmax_f32(
 1561        const ggml_compute_params * params,
 1562        ggml_tensor * dst) {
 1563
 1564    const ggml_tensor * src0 = dst->src[0];
 1565
 1566    if (params->ith != 0) {
 1567        return;
 1568    }
 1569
 1570    assert(src0->nb[0] == sizeof(float));
 1571    assert(dst->nb[0] == sizeof(int32_t));
 1572
 1573    const int64_t ne00 = src0->ne[0];
 1574    const int64_t ne01 = src0->ne[1];
 1575
 1576    const size_t nb01 = src0->nb[1];
 1577    const size_t nb0 = dst->nb[0];
 1578
 1579    for (int64_t i1 = 0; i1 < ne01; i1++) {
 1580        float * src = (float *) ((char *) src0->data + i1*nb01);
 1581        int32_t * dst_ = (int32_t *) ((char *)  dst->data + i1*nb0);
 1582        int v = 0;
 1583        ggml_vec_argmax_f32(ne00, &v, src);
 1584        dst_[0] = v;
 1585    }
 1586}
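
// Note: argmax produces one int32 index per input row; ggml_vec_argmax_f32 scans
// the ne00 values of a row and reports the position of the maximum, which is then
// stored through the dst_ pointer above.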
 1587
 1588void ggml_compute_forward_argmax(
 1589        const ggml_compute_params * params,
 1590        ggml_tensor * dst) {
 1591
 1592    const ggml_tensor * src0 = dst->src[0];
 1593
 1594    switch (src0->type) {
 1595        case GGML_TYPE_F32:
 1596            {
 1597                ggml_compute_forward_argmax_f32(params, dst);
 1598            } break;
 1599        default:
 1600            {
 1601                GGML_ABORT("fatal error");
 1602            }
 1603    }
 1604}
 1605
 1606// ggml_compute_forward_count_equal
 1607
 1608static void ggml_compute_forward_count_equal_i32(
 1609        const ggml_compute_params * params,
 1610        ggml_tensor * dst) {
 1611
 1612    const ggml_tensor * src0 = dst->src[0];
 1613    const ggml_tensor * src1 = dst->src[1];
 1614
 1615    GGML_TENSOR_BINARY_OP_LOCALS;
 1616
 1617    GGML_ASSERT(src0->type == GGML_TYPE_I32);
 1618    GGML_ASSERT(src1->type == GGML_TYPE_I32);
 1619    GGML_ASSERT(ggml_are_same_shape(src0, src1));
 1620    GGML_ASSERT(ggml_is_scalar(dst));
 1621    GGML_ASSERT(dst->type == GGML_TYPE_I64);
 1622
 1623    const int64_t nr = ggml_nrows(src0);
 1624
 1625    const int ith = params->ith;
 1626    const int nth = params->nth;
 1627
 1628    int64_t * sums = (int64_t *) params->wdata;
 1629    int64_t sum_thread = 0;
 1630
 1631    // rows per thread
 1632    const int64_t dr = (nr + nth - 1)/nth;
 1633
 1634    // row range for this thread
 1635    const int64_t ir0 = dr*ith;
 1636    const int64_t ir1 = MIN(ir0 + dr, nr);
 1637
 1638    for (int64_t ir = ir0; ir < ir1; ++ir) {
 1639        const int64_t i03 =  ir                        / (ne02*ne01);
 1640        const int64_t i02 = (ir - i03*ne02*ne01)       /       ne01;
 1641        const int64_t i01 =  ir - i03*ne02*ne01 - i02*ne01;
 1642
 1643        const char * data0 = (const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01;
 1644        const char * data1 = (const char *) src1->data + i03*nb13 + i02*nb12 + i01*nb11;
 1645
 1646        for (int64_t i00 = 0; i00 < ne00; ++i00) {
 1647            const int32_t val0 = *((const int32_t *) (data0 + i00*nb00));
 1648            const int32_t val1 = *((const int32_t *) (data1 + i00*nb10));
 1649
 1650            sum_thread += val0 == val1;
 1651        }
 1652    }
 1653    if (ith != 0) {
 1654        sums[ith] = sum_thread;
 1655    }
 1656    ggml_barrier(params->threadpool);
 1657
 1658    if (ith != 0) {
 1659        return;
 1660    }
 1661
 1662    for (int ith_other = 1; ith_other < nth; ++ith_other) {
 1663        sum_thread += sums[ith_other];
 1664    }
 1665    *((int64_t *) dst->data) = sum_thread;
 1666}
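
// Note on the reduction above: each thread accumulates its partial count in
// sum_thread over its row range [ir0, ir1); threads other than 0 publish their
// partial into params->wdata (assumed to provide at least nth int64_t slots for
// this op), and after the barrier thread 0 adds the other partials to its own
// and writes the scalar result.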
 1667
 1668void ggml_compute_forward_count_equal(
 1669        const ggml_compute_params * params,
 1670        ggml_tensor * dst) {
 1671
 1672    const ggml_tensor * src0 = dst->src[0];
 1673
 1674    switch (src0->type) {
 1675        case GGML_TYPE_I32:
 1676            {
 1677                ggml_compute_forward_count_equal_i32(params, dst);
 1678            } break;
 1679        default:
 1680            {
 1681                GGML_ABORT("fatal error");
 1682            }
 1683    }
 1684}
 1685
 1686// ggml_compute_forward_repeat
 1687
 1688static void ggml_compute_forward_repeat_f32(
 1689        const ggml_compute_params * params,
 1690        ggml_tensor * dst) {
 1691
 1692    const ggml_tensor * src0 = dst->src[0];
 1693
 1694    if (params->ith != 0) {
 1695        return;
 1696    }
 1697
 1698    GGML_ASSERT(ggml_can_repeat(src0, dst));
 1699
 1700    GGML_TENSOR_UNARY_OP_LOCALS
 1701
 1702    // guaranteed to be an integer due to the check in ggml_can_repeat
 1703    const int nr0 = (int)(ne0/ne00);
 1704    const int nr1 = (int)(ne1/ne01);
 1705    const int nr2 = (int)(ne2/ne02);
 1706    const int nr3 = (int)(ne3/ne03);
 1707
 1708    // TODO: support for transposed / permuted tensors
 1709    GGML_ASSERT(nb0  == sizeof(float));
 1710    GGML_ASSERT(nb00 == sizeof(float));
 1711
 1712    // TODO: maybe this is not optimal?
 1713    for                         (int i3 = 0; i3 < nr3;  i3++) {
 1714        for                     (int k3 = 0; k3 < ne03; k3++) {
 1715            for                 (int i2 = 0; i2 < nr2;  i2++) {
 1716                for             (int k2 = 0; k2 < ne02; k2++) {
 1717                    for         (int i1 = 0; i1 < nr1;  i1++) {
 1718                        for     (int k1 = 0; k1 < ne01; k1++) {
 1719                            for (int i0 = 0; i0 < nr0;  i0++) {
 1720                                ggml_vec_cpy_f32(ne00,
 1721                                        (float *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0),
 1722                                        (float *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01));
 1723                            }
 1724                        }
 1725                    }
 1726                }
 1727            }
 1728        }
 1729    }
 1730}
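
// Illustrative note: nr0..nr3 are the per-dimension repetition counts, i.e. how
// many times the source extent fits into the destination extent. For a source of
// shape [2, 3] repeated into a destination of shape [6, 3], nr0 = 3 and
// nr1 = nr2 = nr3 = 1, so each source row is copied three times along dimension 0
// by the innermost ggml_vec_cpy_f32 call.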
 1731
 1732static void ggml_compute_forward_repeat_f16(
 1733        const ggml_compute_params * params,
 1734        ggml_tensor * dst) {
 1735
 1736    const ggml_tensor * src0 = dst->src[0];
 1737
 1738    if (params->ith != 0) {
 1739        return;
 1740    }
 1741
 1742    GGML_ASSERT(ggml_can_repeat(src0, dst));
 1743
 1744    GGML_TENSOR_UNARY_OP_LOCALS
 1745
 1746    // guaranteed to be an integer due to the check in ggml_can_repeat
 1747    const int nr0 = (int)(ne0/ne00);
 1748    const int nr1 = (int)(ne1/ne01);
 1749    const int nr2 = (int)(ne2/ne02);
 1750    const int nr3 = (int)(ne3/ne03);
 1751
 1752    // TODO: support for transposed / permuted tensors
 1753    GGML_ASSERT(nb0  == sizeof(ggml_fp16_t));
 1754    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 1755
 1756    // TODO: maybe this is not optimal?
 1757    for                         (int i3 = 0; i3 < nr3;  i3++) {
 1758        for                     (int k3 = 0; k3 < ne03; k3++) {
 1759            for                 (int i2 = 0; i2 < nr2;  i2++) {
 1760                for             (int k2 = 0; k2 < ne02; k2++) {
 1761                    for         (int i1 = 0; i1 < nr1;  i1++) {
 1762                        for     (int k1 = 0; k1 < ne01; k1++) {
 1763                            for (int i0 = 0; i0 < nr0;  i0++) {
 1764                                ggml_fp16_t * y = (ggml_fp16_t *) ((char *)  dst->data + (i3*ne03 + k3)*nb3  + (i2*ne02 + k2)*nb2  + (i1*ne01 + k1)*nb1  + (i0*ne00)*nb0);
 1765                                ggml_fp16_t * x = (ggml_fp16_t *) ((char *) src0->data + (          k3)*nb03 + (          k2)*nb02 + (          k1)*nb01);
 1766                                // ggml_vec_cpy_f16(ne00, y, x)
 1767                                for (int i = 0; i < ne00; ++i) {
 1768                                    y[i]  = x[i];
 1769                                }
 1770                            }
 1771                        }
 1772                    }
 1773                }
 1774            }
 1775        }
 1776    }
 1777}
 1778
 1779void ggml_compute_forward_repeat(
 1780        const ggml_compute_params * params,
 1781        ggml_tensor * dst) {
 1782
 1783    const ggml_tensor * src0 = dst->src[0];
 1784
 1785    switch (src0->type) {
 1786        case GGML_TYPE_F16:
 1787        case GGML_TYPE_BF16:
 1788        case GGML_TYPE_I16:
 1789            {
 1790                ggml_compute_forward_repeat_f16(params, dst);
 1791            } break;
 1792        case GGML_TYPE_F32:
 1793        case GGML_TYPE_I32:
 1794            {
 1795                ggml_compute_forward_repeat_f32(params, dst);
 1796            } break;
 1797        // TODO: templatify the implementation and support for I64
 1798        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
 1799        //case GGML_TYPE_I64:
 1800        //    {
 1801        //        ggml_compute_forward_repeat_i64(params, dst);
 1802        //    } break;
 1803        default:
 1804            {
 1805                GGML_ABORT("fatal error");
 1806            }
 1807    }
 1808}
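
// Note: F16, BF16 and I16 all route through the f16 kernel because it only moves
// 2-byte elements and performs no arithmetic on them, so the same copy loop is
// valid for any 16-bit type; likewise I32 reuses the f32 kernel as a 4-byte copy.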
 1809
 1810// ggml_compute_forward_repeat_back
 1811
 1812static void ggml_compute_forward_repeat_back_f32(
 1813        const ggml_compute_params * params,
 1814        ggml_tensor * dst) {
 1815
 1816    const ggml_tensor * src0 = dst->src[0];
 1817
 1818    if (params->ith != 0) {
 1819        return;
 1820    }
 1821
 1822    GGML_ASSERT(ggml_can_repeat(dst, src0));
 1823
 1824    GGML_TENSOR_UNARY_OP_LOCALS
 1825
 1826    // guaranteed to be an integer due to the check in ggml_can_repeat
 1827    const int nr0 = (int)(ne00/ne0);
 1828    const int nr1 = (int)(ne01/ne1);
 1829    const int nr2 = (int)(ne02/ne2);
 1830    const int nr3 = (int)(ne03/ne3);
 1831
 1832    // TODO: support for transposed / permuted tensors
 1833    GGML_ASSERT(nb0  == sizeof(float));
 1834    GGML_ASSERT(nb00 == sizeof(float));
 1835
 1836    if (ggml_is_contiguous(dst)) {
 1837        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
 1838    } else {
 1839        for         (int k3 = 0; k3 < ne3; k3++) {
 1840            for     (int k2 = 0; k2 < ne2; k2++) {
 1841                for (int k1 = 0; k1 < ne1; k1++) {
 1842                    ggml_vec_set_f32(ne0,
 1843                        (float *) ((char *) dst->data + k1*nb1 + k2*nb2 + k3*nb3),
 1844                        0);
 1845                }
 1846            }
 1847        }
 1848    }
 1849
 1850    // TODO: maybe this is not optimal?
 1851    for                         (int i3 = 0; i3 < nr3; i3++) {
 1852        for                     (int k3 = 0; k3 < ne3; k3++) {
 1853            for                 (int i2 = 0; i2 < nr2; i2++) {
 1854                for             (int k2 = 0; k2 < ne2; k2++) {
 1855                    for         (int i1 = 0; i1 < nr1; i1++) {
 1856                        for     (int k1 = 0; k1 < ne1; k1++) {
 1857                            for (int i0 = 0; i0 < nr0; i0++) {
 1858                                ggml_vec_acc_f32(ne0,
 1859                                        (float *) ((char *)  dst->data + (         k3)*nb3  + (         k2)*nb2  + (         k1)*nb1),
 1860                                        (float *) ((char *) src0->data + (i3*ne3 + k3)*nb03 + (i2*ne2 + k2)*nb02 + (i1*ne1 + k1)*nb01 + (i0*ne0)*nb00));
 1861                            }
 1862                        }
 1863                    }
 1864                }
 1865            }
 1866        }
 1867    }
 1868}
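
// Note: repeat_back acts as the adjoint of repeat. The destination (the gradient
// of the smaller tensor) is zeroed first and every repeated block of the source
// is then accumulated into it with ggml_vec_acc_f32, so each destination element
// ends up holding the sum over all nr0*nr1*nr2*nr3 of its copies.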
 1869
 1870void ggml_compute_forward_repeat_back(
 1871        const ggml_compute_params * params,
 1872        ggml_tensor * dst) {
 1873
 1874    const ggml_tensor * src0 = dst->src[0];
 1875
 1876    switch (src0->type) {
 1877        case GGML_TYPE_F32:
 1878            {
 1879                ggml_compute_forward_repeat_back_f32(params, dst);
 1880            } break;
 1881        default:
 1882            {
 1883                GGML_ABORT("fatal error");
 1884            }
 1885    }
 1886}
 1887
 1888// ggml_compute_forward_concat
 1889
 1890static void ggml_compute_forward_concat_any(
 1891    const ggml_compute_params * params,
 1892    ggml_tensor * dst) {
 1893
 1894    const ggml_tensor * src0 = dst->src[0];
 1895    const ggml_tensor * src1 = dst->src[1];
 1896
 1897    const size_t len = ggml_type_size(src0->type);
 1898
 1899    const int ith = params->ith;
 1900    const int nth = params->nth;
 1901
 1902    GGML_TENSOR_BINARY_OP_LOCALS
 1903
 1904    const int32_t dim = ggml_get_op_params_i32(dst, 0);
 1905
 1906    GGML_ASSERT(dim >= 0 && dim < 4);
 1907
 1908    int64_t o[4] = {0, 0, 0, 0};
 1909    o[dim] = src0->ne[dim];
 1910
 1911    const char * x;
 1912
 1913    // TODO: smarter multi-threading
 1914    for (int i3 = 0; i3 < ne3; i3++) {
 1915        for (int i2 = ith; i2 < ne2; i2 += nth) {
 1916            for (int i1 = 0; i1 < ne1; i1++) {
 1917                for (int i0 = 0; i0 < ne0; i0++) {
 1918                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
 1919                        x = (const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03;
 1920                    } else {
 1921                        x = (const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13;
 1922                    }
 1923
 1924                    char * y = (char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3;
 1925
 1926                    memcpy(y, x, len);
 1927                }
 1928            }
 1929        }
 1930    }
 1931}
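
// Note on the indexing above: o[dim] holds the extent of src0 along the
// concatenated dimension, so destination coordinates that fall outside src0 are
// mapped into src1 by subtracting that offset. For example, concatenating along
// dim = 2 with src0->ne[2] = 5 means a destination index i2 = 7 reads from src1
// at i2 - 5 = 2. The typed variants below apply the same scheme with direct
// element assignment instead of memcpy.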
 1932
 1933static void ggml_compute_forward_concat_i8(
 1934    const ggml_compute_params * params,
 1935    ggml_tensor * dst) {
 1936
 1937    const ggml_tensor * src0 = dst->src[0];
 1938    const ggml_tensor * src1 = dst->src[1];
 1939
 1940    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(int8_t));
 1941
 1942    const int ith = params->ith;
 1943    const int nth = params->nth;
 1944
 1945    GGML_TENSOR_BINARY_OP_LOCALS
 1946
 1947    const int32_t dim = ggml_get_op_params_i32(dst, 0);
 1948
 1949    GGML_ASSERT(dim >= 0 && dim < 4);
 1950
 1951    int64_t o[4] = {0, 0, 0, 0};
 1952    o[dim] = src0->ne[dim];
 1953
 1954    const int8_t * x;
 1955
 1956    // TODO: smarter multi-threading
 1957    for (int i3 = 0; i3 < ne3; i3++) {
 1958        for (int i2 = ith; i2 < ne2; i2 += nth) {
 1959            for (int i1 = 0; i1 < ne1; i1++) {
 1960                for (int i0 = 0; i0 < ne0; i0++) {
 1961                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
 1962                        x = (const int8_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
 1963                    } else {
 1964                        x = (const int8_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
 1965                    }
 1966
 1967                    int8_t * y = (int8_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 1968
 1969                    *y = *x;
 1970                }
 1971            }
 1972        }
 1973    }
 1974}
 1975
 1976static void ggml_compute_forward_concat_f16(
 1977    const ggml_compute_params * params,
 1978    ggml_tensor * dst) {
 1979
 1980    const ggml_tensor * src0 = dst->src[0];
 1981    const ggml_tensor * src1 = dst->src[1];
 1982
 1983    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(ggml_fp16_t));
 1984
 1985    const int ith = params->ith;
 1986    const int nth = params->nth;
 1987
 1988    GGML_TENSOR_BINARY_OP_LOCALS
 1989
 1990    const int32_t dim = ggml_get_op_params_i32(dst, 0);
 1991
 1992    GGML_ASSERT(dim >= 0 && dim < 4);
 1993
 1994    int64_t o[4] = {0, 0, 0, 0};
 1995    o[dim] = src0->ne[dim];
 1996
 1997    const ggml_fp16_t * x;
 1998
 1999    // TODO: smarter multi-threading
 2000    for (int i3 = 0; i3 < ne3; i3++) {
 2001        for (int i2 = ith; i2 < ne2; i2 += nth) {
 2002            for (int i1 = 0; i1 < ne1; i1++) {
 2003                for (int i0 = 0; i0 < ne0; i0++) {
 2004                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
 2005                        x = (const ggml_fp16_t *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
 2006                    } else {
 2007                        x = (const ggml_fp16_t *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
 2008                    }
 2009
 2010                    ggml_fp16_t * y = (ggml_fp16_t *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 2011
 2012                    *y = *x;
 2013                }
 2014            }
 2015        }
 2016    }
 2017}
 2018
 2019static void ggml_compute_forward_concat_f32(
 2020    const ggml_compute_params * params,
 2021    ggml_tensor * dst) {
 2022
 2023    const ggml_tensor * src0 = dst->src[0];
 2024    const ggml_tensor * src1 = dst->src[1];
 2025
 2026    GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float));
 2027
 2028    const int ith = params->ith;
 2029    const int nth = params->nth;
 2030
 2031    GGML_TENSOR_BINARY_OP_LOCALS
 2032
 2033    const int32_t dim = ggml_get_op_params_i32(dst, 0);
 2034
 2035    GGML_ASSERT(dim >= 0 && dim < 4);
 2036
 2037    int64_t o[4] = {0, 0, 0, 0};
 2038    o[dim] = src0->ne[dim];
 2039
 2040    const float * x;
 2041
 2042    // TODO: smarter multi-threading
 2043    for (int i3 = 0; i3 < ne3; i3++) {
 2044        for (int i2 = ith; i2 < ne2; i2 += nth) {
 2045            for (int i1 = 0; i1 < ne1; i1++) {
 2046                for (int i0 = 0; i0 < ne0; i0++) {
 2047                    if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
 2048                        x = (const float *) ((const char *)src0->data + (i0       )*nb00 + (i1       )*nb01 + (i2       )*nb02 + (i3       )*nb03);
 2049                    } else {
 2050                        x = (const float *) ((const char *)src1->data + (i0 - o[0])*nb10 + (i1 - o[1])*nb11 + (i2 - o[2])*nb12 + (i3 - o[3])*nb13);
 2051                    }
 2052
 2053                    float * y = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 2054
 2055                    *y = *x;
 2056                }
 2057            }
 2058        }
 2059    }
 2060}
 2061
 2062void ggml_compute_forward_concat(
 2063    const ggml_compute_params * params,
 2064    ggml_tensor * dst) {
 2065
 2066    const ggml_tensor * src0 = dst->src[0];
 2067
 2068    switch (src0->type) {
 2069        case GGML_TYPE_F16:
 2070        case GGML_TYPE_BF16:
 2071        case GGML_TYPE_I16:
 2072            {
 2073                ggml_compute_forward_concat_f16(params, dst);
 2074            } break;
 2075        case GGML_TYPE_I8:
 2076            {
 2077                ggml_compute_forward_concat_i8(params, dst);
 2078            } break;
 2079        case GGML_TYPE_F32:
 2080        case GGML_TYPE_I32:
 2081            {
 2082                ggml_compute_forward_concat_f32(params, dst);
 2083            } break;
 2084        default:
 2085            {
 2086                ggml_compute_forward_concat_any(params, dst);
 2087            }
 2088    }
 2089}
 2090
 2091// ggml_compute_forward_gelu
 2092
 2093static void ggml_compute_forward_gelu_f32(
 2094        const ggml_compute_params * params,
 2095        ggml_tensor * dst) {
 2096
 2097    const ggml_tensor * src0 = dst->src[0];
 2098
 2099    assert(ggml_is_contiguous_rows(src0));
 2100    assert(ggml_are_same_shape(src0, dst));
 2101
 2102    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2103    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2104    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2105    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2106
 2107    const int ith = params->ith;
 2108    const int nth = params->nth;
 2109
 2110    const int nc = src0->ne[0];
 2111    const int nr = ggml_nrows(src0);
 2112
 2113    // rows per thread
 2114    const int dr = (nr + nth - 1)/nth;
 2115
 2116    // row range for this thread
 2117    const int ir0 = dr*ith;
 2118    const int ir1 = MIN(ir0 + dr, nr);
 2119
 2120    for (int ir = ir0; ir < ir1; ++ir) {
 2121        const int i3 = ir/(ne02*ne01);
 2122        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2123        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2124
 2125        ggml_vec_gelu_f32(nc,
 2126                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2127                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2128
 2129#ifndef NDEBUG
 2130        for (int k = 0; k < nc; k++) {
 2131            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2132            GGML_UNUSED(x);
 2133            assert(!isnan(x));
 2134            assert(!isinf(x));
 2135        }
 2136#endif
 2137    }
 2138}
 2139
 2140static void ggml_compute_forward_gelu_f16(
 2141    const ggml_compute_params * params,
 2142    ggml_tensor * dst) {
 2143
 2144    const ggml_tensor * src0 = dst->src[0];
 2145
 2146    assert(ggml_is_contiguous_rows(src0));
 2147    assert(ggml_are_same_shape(src0, dst));
 2148
 2149    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2150    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2151    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2152    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2153
 2154    const int ith = params->ith;
 2155    const int nth = params->nth;
 2156
 2157    const int nc = src0->ne[0];
 2158    const int nr = ggml_nrows(src0);
 2159
 2160    // rows per thread
 2161    const int dr = (nr + nth - 1)/nth;
 2162
 2163    // row range for this thread
 2164    const int ir0 = dr*ith;
 2165    const int ir1 = MIN(ir0 + dr, nr);
 2166
 2167    for (int ir = ir0; ir < ir1; ++ir) {
 2168        const int i3 = ir/(ne02*ne01);
 2169        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2170        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2171
 2172        ggml_vec_gelu_f16(nc,
 2173                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2174                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2175
 2176#ifndef NDEBUG
 2177        for (int k = 0; k < nc; k++) {
 2178            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2179            const float v = GGML_CPU_FP16_TO_FP32(x);
 2180            GGML_UNUSED(v);
 2181            assert(!isnan(v));
 2182            assert(!isinf(v));
 2183        }
 2184#endif
 2185    }
 2186}
 2187
 2188static void ggml_compute_forward_gelu(
 2189        const ggml_compute_params * params,
 2190        ggml_tensor * dst) {
 2191
 2192    const ggml_tensor * src0 = dst->src[0];
 2193
 2194    switch (src0->type) {
 2195        case GGML_TYPE_F32:
 2196            {
 2197                ggml_compute_forward_gelu_f32(params, dst);
 2198            } break;
 2199        case GGML_TYPE_F16:
 2200            {
 2201                ggml_compute_forward_gelu_f16(params, dst);
 2202            } break;
 2203        default:
 2204            {
 2205                GGML_ABORT("fatal error");
 2206            }
 2207    }
 2208}
 2209
 2210// ggml_compute_forward_fill
 2211
 2212static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, ggml_tensor * dst) {
 2213    const float c = ggml_get_op_params_f32(dst, 0);
 2214
 2215    GGML_TENSOR_LOCALS(int64_t, ne, dst, ne);
 2216    GGML_TENSOR_LOCALS(size_t,  nb, dst, nb);
 2217
 2218    const auto [ir0, ir1] = get_thread_range(params, dst);
 2219
 2220    for (int64_t ir = ir0; ir < ir1; ++ir) {
 2221        const int64_t i03 = ir/(ne2*ne1);
 2222        const int64_t i02 = (ir - i03*ne2*ne1)/ne1;
 2223        const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1);
 2224
 2225        float * dst_ptr  = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1);
 2226
 2227        ggml_vec_set_f32(ne0, dst_ptr, c);
 2228    }
 2229}
 2230
 2231void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) {
 2232    ggml_compute_forward_fill_f32(params, dst);
 2233}
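
// Note: FILL does not read any source tensor; the constant is carried in the
// first op parameter as an f32 and each thread writes its share of rows
// (get_thread_range is assumed to return the same [ir0, ir1) row split that the
// kernels above derive manually from ith/nth).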
 2234
 2235// ggml_compute_forward_tri
 2236
 2237static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggml_tensor * dst) {
 2238    const ggml_tensor * src0 = dst->src[0];
 2239
 2240    const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0);
 2241
 2242    GGML_ASSERT(ggml_is_contiguous(src0));
 2243
 2244    GGML_TENSOR_UNARY_OP_LOCALS
 2245
 2246    const auto [ir0, ir1] = get_thread_range(params, src0);
 2247
 2248    bool (*bipred)(int, int);
 2249
 2250    switch (ttype) {
 2251        case GGML_TRI_TYPE_LOWER:      bipred = [](int i, int r) { return i <  r; }; break;
 2252        case GGML_TRI_TYPE_LOWER_DIAG: bipred = [](int i, int r) { return i <= r; }; break;
 2253        case GGML_TRI_TYPE_UPPER:      bipred = [](int i, int r) { return i >  r; }; break;
 2254        case GGML_TRI_TYPE_UPPER_DIAG: bipred = [](int i, int r) { return i >= r; }; break;
 2255        default: GGML_ABORT("invalid tri type");
 2256    }
 2257
 2258    for (int64_t ir = ir0; ir < ir1; ++ir) {
 2259        const int64_t i03 = ir/(ne02*ne01);
 2260        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
 2261        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
 2262
 2263        const float * src_ptr = (const float  *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
 2264              float * dst_ptr = (      float  *) ((      char *) dst->data  + i03*nb3  + i02*nb2  + i01*nb1);
 2265
 2266        for (int i0 = 0; i0 < ne0; ++i0) {
 2267            dst_ptr[i0] = bipred(i0, i01) ? src_ptr[i0] : 0.0f;
 2268        }
 2269    }
 2270}
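
// Illustrative note: bipred selects which elements of a row survive by comparing
// the column index i0 against the row index i01:
//
//   LOWER       keeps i0 <  i01   (strictly below the diagonal)
//   LOWER_DIAG  keeps i0 <= i01   (below, including the diagonal)
//   UPPER       keeps i0 >  i01   (strictly above the diagonal)
//   UPPER_DIAG  keeps i0 >= i01   (above, including the diagonal)
//
// everything else is written as 0.0f.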
 2271
 2272void ggml_compute_forward_tri(const ggml_compute_params * params, ggml_tensor * dst) {
 2273    const ggml_tensor * src0 = dst->src[0];
 2274
 2275    switch (src0->type) {
 2276        case GGML_TYPE_F32:
 2277            {
 2278                ggml_compute_forward_tri_f32(params, dst);
 2279            } break;
 2280        default:
 2281            {
 2282                GGML_ABORT("fatal error");
 2283            }
 2284    }
 2285}
 2286
 2287// ggml_compute_forward_gelu_erf
 2288
 2289static void ggml_compute_forward_gelu_erf_f32(
 2290        const ggml_compute_params * params,
 2291        ggml_tensor * dst) {
 2292
 2293    const ggml_tensor * src0 = dst->src[0];
 2294
 2295    assert(ggml_is_contiguous_rows(src0));
 2296    assert(ggml_are_same_shape(src0, dst));
 2297
 2298    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2299    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2300    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2301    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2302
 2303    const int ith = params->ith;
 2304    const int nth = params->nth;
 2305
 2306    const int nc = src0->ne[0];
 2307    const int nr = ggml_nrows(src0);
 2308
 2309    // rows per thread
 2310    const int dr = (nr + nth - 1)/nth;
 2311
 2312    // row range for this thread
 2313    const int ir0 = dr*ith;
 2314    const int ir1 = MIN(ir0 + dr, nr);
 2315
 2316    for (int ir = ir0; ir < ir1; ++ir) {
 2317        const int i3 = ir/(ne02*ne01);
 2318        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2319        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2320
 2321        ggml_vec_gelu_erf_f32(nc,
 2322                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2323                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2324
 2325#ifndef NDEBUG
 2326        for (int k = 0; k < nc; k++) {
 2327            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2328            GGML_UNUSED(x);
 2329            assert(!isnan(x));
 2330            assert(!isinf(x));
 2331        }
 2332#endif
 2333    }
 2334}
 2335
 2336static void ggml_compute_forward_gelu_erf_f16(
 2337    const ggml_compute_params * params,
 2338    ggml_tensor * dst) {
 2339
 2340    const ggml_tensor * src0 = dst->src[0];
 2341
 2342    assert(ggml_is_contiguous_rows(src0));
 2343    assert(ggml_are_same_shape(src0, dst));
 2344
 2345    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2346    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2347    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2348    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2349
 2350    const int ith = params->ith;
 2351    const int nth = params->nth;
 2352
 2353    const int nc = src0->ne[0];
 2354    const int nr = ggml_nrows(src0);
 2355
 2356    // rows per thread
 2357    const int dr = (nr + nth - 1)/nth;
 2358
 2359    // row range for this thread
 2360    const int ir0 = dr*ith;
 2361    const int ir1 = MIN(ir0 + dr, nr);
 2362
 2363    for (int ir = ir0; ir < ir1; ++ir) {
 2364        const int i3 = ir/(ne02*ne01);
 2365        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2366        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2367
 2368        ggml_vec_gelu_erf_f16(nc,
 2369                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2370                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2371
 2372#ifndef NDEBUG
 2373        for (int k = 0; k < nc; k++) {
 2374            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2375            const float v = GGML_CPU_FP16_TO_FP32(x);
 2376            GGML_UNUSED(v);
 2377            assert(!isnan(v));
 2378            assert(!isinf(v));
 2379        }
 2380#endif
 2381    }
 2382}
 2383
 2384static void ggml_compute_forward_gelu_erf(
 2385        const ggml_compute_params * params,
 2386        ggml_tensor * dst) {
 2387
 2388    const ggml_tensor * src0 = dst->src[0];
 2389
 2390    switch (src0->type) {
 2391        case GGML_TYPE_F32:
 2392            {
 2393                ggml_compute_forward_gelu_erf_f32(params, dst);
 2394            } break;
 2395        case GGML_TYPE_F16:
 2396            {
 2397                ggml_compute_forward_gelu_erf_f16(params, dst);
 2398            } break;
 2399        default:
 2400            {
 2401                GGML_ABORT("fatal error");
 2402            }
 2403    }
 2404}
 2405
 2406// ggml_compute_forward_gelu_quick
 2407
 2408static void ggml_compute_forward_gelu_quick_f32(
 2409        const ggml_compute_params * params,
 2410        ggml_tensor * dst) {
 2411
 2412    const ggml_tensor * src0 = dst->src[0];
 2413
 2414    assert(ggml_is_contiguous_rows(src0));
 2415    assert(ggml_are_same_shape(src0, dst));
 2416
 2417    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2418    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2419    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2420    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2421
 2422    const int ith = params->ith;
 2423    const int nth = params->nth;
 2424
 2425    const int nc = src0->ne[0];
 2426    const int nr = ggml_nrows(src0);
 2427
 2428    // rows per thread
 2429    const int dr = (nr + nth - 1)/nth;
 2430
 2431    // row range for this thread
 2432    const int ir0 = dr*ith;
 2433    const int ir1 = MIN(ir0 + dr, nr);
 2434
 2435    for (int ir = ir0; ir < ir1; ++ir) {
 2436        const int i3 = ir/(ne02*ne01);
 2437        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2438        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2439
 2440        ggml_vec_gelu_quick_f32(nc,
 2441                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2442                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2443
 2444#ifndef NDEBUG
 2445        for (int k = 0; k < nc; k++) {
 2446            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2447            GGML_UNUSED(x);
 2448            assert(!isnan(x));
 2449            assert(!isinf(x));
 2450        }
 2451#endif
 2452    }
 2453}
 2454
 2455static void ggml_compute_forward_gelu_quick_f16(
 2456    const ggml_compute_params * params,
 2457    ggml_tensor * dst) {
 2458
 2459    const ggml_tensor * src0 = dst->src[0];
 2460
 2461    assert(ggml_is_contiguous_rows(src0));
 2462    assert(ggml_are_same_shape(src0, dst));
 2463
 2464    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2465    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2466    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2467    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2468
 2469    const int ith = params->ith;
 2470    const int nth = params->nth;
 2471
 2472    const int nc = src0->ne[0];
 2473    const int nr = ggml_nrows(src0);
 2474
 2475    // rows per thread
 2476    const int dr = (nr + nth - 1)/nth;
 2477
 2478    // row range for this thread
 2479    const int ir0 = dr*ith;
 2480    const int ir1 = MIN(ir0 + dr, nr);
 2481
 2482    for (int ir = ir0; ir < ir1; ++ir) {
 2483        const int i3 = ir/(ne02*ne01);
 2484        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2485        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2486
 2487        ggml_vec_gelu_quick_f16(nc,
 2488                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2489                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2490
 2491#ifndef NDEBUG
 2492        for (int k = 0; k < nc; k++) {
 2493            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2494            const float v = GGML_CPU_FP16_TO_FP32(x);
 2495            GGML_UNUSED(v);
 2496            assert(!isnan(v));
 2497            assert(!isinf(v));
 2498        }
 2499#endif
 2500    }
 2501}
 2502
 2503static void ggml_compute_forward_gelu_quick(
 2504        const ggml_compute_params * params,
 2505        ggml_tensor * dst) {
 2506
 2507    const ggml_tensor * src0 = dst->src[0];
 2508
 2509    switch (src0->type) {
 2510        case GGML_TYPE_F32:
 2511            {
 2512                ggml_compute_forward_gelu_quick_f32(params, dst);
 2513            } break;
 2514        case GGML_TYPE_F16:
 2515            {
 2516                ggml_compute_forward_gelu_quick_f16(params, dst);
 2517            } break;
 2518        default:
 2519            {
 2520                GGML_ABORT("fatal error");
 2521            }
 2522    }
 2523}
 2524
 2525// ggml_compute_forward_silu
 2526
 2527static void ggml_compute_forward_silu_f32(
 2528        const ggml_compute_params * params,
 2529        ggml_tensor * dst) {
 2530
 2531    const ggml_tensor * src0 = dst->src[0];
 2532
 2533    assert(ggml_is_contiguous_rows(src0));
 2534    assert(ggml_are_same_shape(src0, dst));
 2535
 2536    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2537    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2538    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2539    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2540
 2541    const int ith = params->ith;
 2542    const int nth = params->nth;
 2543
 2544    const int nc = src0->ne[0];
 2545    const int nr = ggml_nrows(src0);
 2546
 2547    // rows per thread
 2548    const int dr = (nr + nth - 1)/nth;
 2549
 2550    // row range for this thread
 2551    const int ir0 = dr*ith;
 2552    const int ir1 = MIN(ir0 + dr, nr);
 2553
 2554    for (int ir = ir0; ir < ir1; ++ir) {
 2555        const int i3 = ir/(ne02*ne01);
 2556        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2557        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2558
 2559        ggml_vec_silu_f32(nc,
 2560                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2561                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2562
 2563#ifndef NDEBUG
 2564        for (int k = 0; k < nc; k++) {
 2565            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
 2566            GGML_UNUSED(x);
 2567            assert(!isnan(x));
 2568            assert(!isinf(x));
 2569        }
 2570#endif
 2571    }
 2572}
 2573
 2574static void ggml_compute_forward_silu_f16(
 2575    const ggml_compute_params * params,
 2576    ggml_tensor * dst) {
 2577
 2578    const ggml_tensor * src0 = dst->src[0];
 2579
 2580    assert(ggml_is_contiguous_rows(src0));
 2581    assert(ggml_are_same_shape(src0, dst));
 2582
 2583    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 2584    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb)
 2585    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 2586    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 2587
 2588    const int ith = params->ith;
 2589    const int nth = params->nth;
 2590
 2591    const int nc = src0->ne[0];
 2592    const int nr = ggml_nrows(src0);
 2593
 2594    // rows per thread
 2595    const int dr = (nr + nth - 1)/nth;
 2596
 2597    // row range for this thread
 2598    const int ir0 = dr*ith;
 2599    const int ir1 = MIN(ir0 + dr, nr);
 2600
 2601    for (int ir = ir0; ir < ir1; ++ir) {
 2602        const int i3 = ir/(ne02*ne01);
 2603        const int i2 = (ir - i3*ne02*ne01)/ne01;
 2604        const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
 2605
 2606        ggml_vec_silu_f16(nc,
 2607                (ggml_fp16_t *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1),
 2608                (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 2609
 2610#ifndef NDEBUG
 2611        for (int k = 0; k < nc; k++) {
 2612            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k];
 2613            const float v = GGML_CPU_FP16_TO_FP32(x);
 2614            GGML_UNUSED(v);
 2615            assert(!isnan(v));
 2616            assert(!isinf(v));
 2617        }
 2618#endif
 2619    }
 2620}
 2621
 2622static void ggml_compute_forward_silu(
 2623        const ggml_compute_params * params,
 2624        ggml_tensor * dst) {
 2625
 2626    const ggml_tensor * src0 = dst->src[0];
 2627
 2628    switch (src0->type) {
 2629        case GGML_TYPE_F32:
 2630            {
 2631                ggml_compute_forward_silu_f32(params, dst);
 2632            } break;
 2633        case GGML_TYPE_F16:
 2634            {
 2635                ggml_compute_forward_silu_f16(params, dst);
 2636            } break;
 2637        default:
 2638            {
 2639                GGML_ABORT("fatal error");
 2640            }
 2641    }
 2642}

 2643// ggml_compute_forward_leaky_relu
 2644
 2645static void ggml_compute_forward_leaky_relu_f32(
 2646        const ggml_compute_params * params,
 2647        ggml_tensor * dst) {
 2648
 2649    const ggml_tensor * src0 = dst->src[0];
 2650
 2651    if (params->ith != 0) {
 2652        return;
 2653    }
 2654
 2655    assert(ggml_is_contiguous_1(src0));
 2656    assert(ggml_is_contiguous_1(dst));
 2657    assert(ggml_are_same_shape(src0, dst));
 2658
 2659    const int n  = ggml_nrows(src0);
 2660    const int nc = src0->ne[0];
 2661
 2662    float negative_slope;
 2663    memcpy(&negative_slope, dst->op_params, sizeof(float));
 2664
 2665    assert(dst->nb[0]  == sizeof(float));
 2666    assert(src0->nb[0] == sizeof(float));
 2667
 2668    for (int i = 0; i < n; i++) {
 2669        ggml_vec_leaky_relu_f32(nc,
 2670                (float *) ((char *) dst->data  + i*( dst->nb[1])),
 2671                (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
 2672    }
 2673}
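
// Note: negative_slope is stored in dst->op_params as raw float bytes and is read
// back with memcpy rather than a pointer cast, which keeps the access free of
// strict-aliasing concerns; the f16 variant below uses the same pattern.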
 2674
 2675static void ggml_compute_forward_leaky_relu_f16(
 2676    const ggml_compute_params * params,
 2677    ggml_tensor * dst) {
 2678
 2679    const ggml_tensor * src0 = dst->src[0];
 2680
 2681    if (params->ith != 0) {
 2682        return;
 2683    }
 2684
 2685    assert(ggml_is_contiguous_1(src0));
 2686    assert(ggml_is_contiguous_1(dst));
 2687    assert(ggml_are_same_shape(src0, dst));
 2688
 2689    const int n  = ggml_nrows(src0);
 2690    const int nc = src0->ne[0];
 2691
 2692    float negative_slope;
 2693    memcpy(&negative_slope, dst->op_params, sizeof(float));
 2694
 2695    assert(dst->nb[0]  == sizeof(ggml_fp16_t));
 2696    assert(src0->nb[0] == sizeof(ggml_fp16_t));
 2697
 2698    for (int i = 0; i < n; i++) {
 2699        ggml_vec_leaky_relu_f16(nc,
 2700                (ggml_fp16_t *) ((char *) dst->data  + i*( dst->nb[1])),
 2701                (ggml_fp16_t *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
 2702    }
 2703}
 2704
 2705void ggml_compute_forward_leaky_relu(
 2706        const ggml_compute_params * params,
 2707        ggml_tensor * dst) {
 2708
 2709    const ggml_tensor * src0 = dst->src[0];
 2710
 2711    switch (src0->type) {
 2712        case GGML_TYPE_F32:
 2713            {
 2714                ggml_compute_forward_leaky_relu_f32(params, dst);
 2715            } break;
 2716        case GGML_TYPE_F16:
 2717            {
 2718                ggml_compute_forward_leaky_relu_f16(params, dst);
 2719            } break;
 2720        default:
 2721            {
 2722                GGML_ABORT("fatal error");
 2723            }
 2724    }
 2725}
 2726
 2727// ggml_compute_forward_silu_back
 2728
 2729static void ggml_compute_forward_silu_back_f32(
 2730        const ggml_compute_params * params,
 2731        ggml_tensor * dst) {
 2732
 2733    const ggml_tensor * grad = dst->src[0];
 2734    const ggml_tensor * src1 = dst->src[1];
 2735
 2736    assert(ggml_is_contiguous_1(grad));
 2737    assert(ggml_is_contiguous_1(src1));
 2738    assert(ggml_is_contiguous_1(dst));
 2739    assert(ggml_are_same_shape(src1, dst));
 2740    assert(ggml_are_same_shape(src1, grad));
 2741
 2742    const int ith = params->ith;
 2743    const int nth = params->nth;
 2744
 2745    const int nc = src1->ne[0];
 2746    const int nr = ggml_nrows(src1);
 2747
 2748    // rows per thread
 2749    const int dr = (nr + nth - 1)/nth;
 2750
 2751    // row range for this thread
 2752    const int ir0 = dr*ith;
 2753    const int ir1 = MIN(ir0 + dr, nr);
 2754
 2755    for (int i1 = ir0; i1 < ir1; i1++) {
 2756        ggml_vec_silu_backward_f32(nc,
 2757                (float *) ((char *) dst->data  + i1*( dst->nb[1])),
 2758                (float *) ((char *) src1->data + i1*(src1->nb[1])),
 2759                (float *) ((char *) grad->data + i1*(grad->nb[1])));
 2760
 2761#ifndef NDEBUG
 2762        for (int k = 0; k < nc; k++) {
 2763            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2764            GGML_UNUSED(x);
 2765            assert(!isnan(x));
 2766            assert(!isinf(x));
 2767        }
 2768#endif
 2769    }
 2770}
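
// Note: for SILU_BACK the first source is the incoming gradient and the second is
// the original forward input; ggml_vec_silu_backward_f32 combines them into the
// input gradient row by row over the same [ir0, ir1) thread split used by the
// forward kernels.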
 2771
 2772static void ggml_compute_forward_silu_back_f16(
 2773    const ggml_compute_params * params,
 2774    ggml_tensor * dst) {
 2775
 2776    const ggml_tensor * grad = dst->src[0];
 2777    const ggml_tensor * src1 = dst->src[1];
 2778
 2779    assert(ggml_is_contiguous_1(grad));
 2780    assert(ggml_is_contiguous_1(src1));
 2781    assert(ggml_is_contiguous_1(dst));
 2782    assert(ggml_are_same_shape(src1, dst));
 2783    assert(ggml_are_same_shape(src1, grad));
 2784
 2785    const int ith = params->ith;
 2786    const int nth = params->nth;
 2787
 2788    const int nc = src1->ne[0];
 2789    const int nr = ggml_nrows(src1);
 2790
 2791    // rows per thread
 2792    const int dr = (nr + nth - 1)/nth;
 2793
 2794    // row range for this thread
 2795    const int ir0 = dr*ith;
 2796    const int ir1 = MIN(ir0 + dr, nr);
 2797
 2798    for (int i1 = ir0; i1 < ir1; i1++) {
 2799        ggml_vec_silu_backward_f16(nc,
 2800                (ggml_fp16_t *) ((char *) dst->data  + i1*( dst->nb[1])),
 2801                (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])),
 2802                (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1])));
 2803
 2804    #ifndef NDEBUG
 2805        for (int k = 0; k < nc; k++) {
 2806            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2807            const float v = GGML_CPU_FP16_TO_FP32(x);
 2808            GGML_UNUSED(v);
 2809            assert(!isnan(v));
 2810            assert(!isinf(v));
 2811        }
 2812    #endif
 2813    }
 2814}
 2815
 2816void ggml_compute_forward_silu_back(
 2817        const ggml_compute_params * params,
 2818        ggml_tensor * dst) {
 2819
 2820    const ggml_tensor * src0 = dst->src[0];
 2821
 2822    switch (src0->type) {
 2823        case GGML_TYPE_F32:
 2824            {
 2825                ggml_compute_forward_silu_back_f32(params, dst);
 2826            } break;
 2827        case GGML_TYPE_F16:
 2828            {
 2829                ggml_compute_forward_silu_back_f16(params, dst);
 2830            } break;
 2831        default:
 2832            {
 2833                GGML_ABORT("fatal error");
 2834            }
 2835    }
 2836}
 2837
 2838// ggml_compute_forward_reglu
 2839
 2840static void ggml_compute_forward_reglu_f32(
 2841        const ggml_compute_params * params,
 2842        ggml_tensor * dst) {
 2843
 2844    const ggml_tensor * src0 = dst->src[0];
 2845    const ggml_tensor * src1 = dst->src[1];
 2846    char * src0_d = (char *) src0->data;
 2847    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 2848    const size_t src0_o = src0->nb[1];
 2849    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 2850
 2851    GGML_ASSERT(ggml_is_contiguous_1(src0));
 2852    GGML_ASSERT(ggml_is_contiguous_1(dst));
 2853
 2854    if (src1) {
 2855        GGML_ASSERT(ggml_is_contiguous_1(src1));
 2856        GGML_ASSERT(src0->type == src1->type);
 2857    }
 2858
 2859    const int ith = params->ith;
 2860    const int nth = params->nth;
 2861
 2862    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 2863    const int nr = ggml_nrows(src0);
 2864
 2865    GGML_ASSERT(dst->ne[0] == nc);
 2866    GGML_ASSERT(ggml_nrows(dst) == nr);
 2867
 2868    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 2869
 2870    // rows per thread
 2871    const int dr = (nr + nth - 1)/nth;
 2872
 2873    // row range for this thread
 2874    const int ir0 = dr*ith;
 2875    const int ir1 = MIN(ir0 + dr, nr);
 2876
 2877    for (int i1 = ir0; i1 < ir1; i1++) {
 2878        float * src0_p = (float *) (src0_d + i1*src0_o);
 2879        float * src1_p = (float *) (src1_d + i1*src1_o);
 2880
 2881        if (!src1) {
 2882            src0_p += swapped ? nc : 0;
 2883            src1_p += swapped ? 0 : nc;
 2884        }
 2885
 2886        ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 2887
 2888#ifndef NDEBUG
 2889        for (int k = 0; k < nc; k++) {
 2890            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2891            GGML_UNUSED(x);
 2892            assert(!isnan(x));
 2893            assert(!isinf(x));
 2894        }
 2895#endif
 2896    }
 2897}
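
// Note on the GLU layout above (the geglu/swiglu kernels below share it): with a
// single input, each source row packs both operands side by side as [a | b] with
// nc = ne00/2 elements each, and the `swapped` flag decides which half is
// activated and which acts as the gate; with two inputs, src0 and src1 each
// supply a full row of nc elements. For REGLU the result is, illustratively,
// dst[k] = relu(src0_p[k]) * src1_p[k] (assuming that is what ggml_vec_reglu_f32
// computes).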
 2898
 2899static void ggml_compute_forward_reglu_f16(
 2900    const ggml_compute_params * params,
 2901    ggml_tensor * dst) {
 2902
 2903    const ggml_tensor * src0 = dst->src[0];
 2904    const ggml_tensor * src1 = dst->src[1];
 2905    char * src0_d = (char *) src0->data;
 2906    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 2907    const size_t src0_o = src0->nb[1];
 2908    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 2909
 2910    GGML_ASSERT(ggml_is_contiguous_1(src0));
 2911    GGML_ASSERT(ggml_is_contiguous_1(dst));
 2912
 2913    if (src1) {
 2914        GGML_ASSERT(ggml_is_contiguous_1(src1));
 2915        GGML_ASSERT(src0->type == src1->type);
 2916    }
 2917
 2918    const int ith = params->ith;
 2919    const int nth = params->nth;
 2920
 2921    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 2922    const int nr = ggml_nrows(src0);
 2923
 2924    GGML_ASSERT(dst->ne[0] == nc);
 2925    GGML_ASSERT(ggml_nrows(dst) == nr);
 2926
 2927    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 2928
 2929    // rows per thread
 2930    const int dr = (nr + nth - 1)/nth;
 2931
 2932    // row range for this thread
 2933    const int ir0 = dr*ith;
 2934    const int ir1 = MIN(ir0 + dr, nr);
 2935
 2936    for (int i1 = ir0; i1 < ir1; i1++) {
 2937        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
 2938        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 2939
 2940        if (!src1) {
 2941            src0_p += swapped ? nc : 0;
 2942            src1_p += swapped ? 0 : nc;
 2943        }
 2944
 2945        ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 2946
 2947#ifndef NDEBUG
 2948        for (int k = 0; k < nc; k++) {
 2949            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 2950            const float v = GGML_FP16_TO_FP32(x);
 2951            GGML_UNUSED(v);
 2952            assert(!isnan(v));
 2953            assert(!isinf(v));
 2954        }
 2955#endif
 2956    }
 2957}
 2958
 2959static void ggml_compute_forward_reglu(
 2960        const ggml_compute_params * params,
 2961        ggml_tensor * dst) {
 2962
 2963    const ggml_tensor * src0 = dst->src[0];
 2964
 2965    switch (src0->type) {
 2966        case GGML_TYPE_F32:
 2967            {
 2968                ggml_compute_forward_reglu_f32(params, dst);
 2969            } break;
 2970        case GGML_TYPE_F16:
 2971            {
 2972                ggml_compute_forward_reglu_f16(params, dst);
 2973            } break;
 2974        default:
 2975            {
 2976                GGML_ABORT("fatal error");
 2977            }
 2978    }
 2979}
 2980
 2981// ggml_compute_forward_geglu
 2982
 2983static void ggml_compute_forward_geglu_f32(
 2984        const ggml_compute_params * params,
 2985        ggml_tensor * dst) {
 2986
 2987    const ggml_tensor * src0 = dst->src[0];
 2988    const ggml_tensor * src1 = dst->src[1];
 2989    char * src0_d = (char *) src0->data;
 2990    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 2991    const size_t src0_o = src0->nb[1];
 2992    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 2993
 2994    GGML_ASSERT(ggml_is_contiguous_1(src0));
 2995    GGML_ASSERT(ggml_is_contiguous_1(dst));
 2996
 2997    if (src1) {
 2998        GGML_ASSERT(ggml_is_contiguous_1(src1));
 2999        GGML_ASSERT(src0->type == src1->type);
 3000    }
 3001
 3002    const int ith = params->ith;
 3003    const int nth = params->nth;
 3004
 3005    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3006    const int nr = ggml_nrows(src0);
 3007
 3008    GGML_ASSERT(dst->ne[0] == nc);
 3009    GGML_ASSERT(ggml_nrows(dst) == nr);
 3010
 3011    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3012
 3013    // rows per thread
 3014    const int dr = (nr + nth - 1)/nth;
 3015
 3016    // row range for this thread
 3017    const int ir0 = dr*ith;
 3018    const int ir1 = MIN(ir0 + dr, nr);
 3019
 3020    for (int i1 = ir0; i1 < ir1; i1++) {
 3021        float * src0_p = (float *) (src0_d + i1*src0_o);
 3022        float * src1_p = (float *) (src1_d + i1*src1_o);
 3023
 3024        if (!src1) {
 3025            src0_p += swapped ? nc : 0;
 3026            src1_p += swapped ? 0 : nc;
 3027        }
 3028
 3029        ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3030
 3031#ifndef NDEBUG
 3032        for (int k = 0; k < nc; k++) {
 3033            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3034            GGML_UNUSED(x);
 3035            assert(!isnan(x));
 3036            assert(!isinf(x));
 3037        }
 3038#endif
 3039    }
 3040}
 3041
 3042static void ggml_compute_forward_geglu_f16(
 3043    const ggml_compute_params * params,
 3044    ggml_tensor * dst) {
 3045
 3046    const ggml_tensor * src0 = dst->src[0];
 3047    const ggml_tensor * src1 = dst->src[1];
 3048    char * src0_d = (char *) src0->data;
 3049    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3050    const size_t src0_o = src0->nb[1];
 3051    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3052
 3053    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3054    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3055
 3056    if (src1) {
 3057        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3058        GGML_ASSERT(src0->type == src1->type);
 3059    }
 3060
 3061    const int ith = params->ith;
 3062    const int nth = params->nth;
 3063
 3064    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3065    const int nr = ggml_nrows(src0);
 3066
 3067    GGML_ASSERT(dst->ne[0] == nc);
 3068    GGML_ASSERT(ggml_nrows(dst) == nr);
 3069
 3070    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3071
 3072    // rows per thread
 3073    const int dr = (nr + nth - 1)/nth;
 3074
 3075    // row range for this thread
 3076    const int ir0 = dr*ith;
 3077    const int ir1 = MIN(ir0 + dr, nr);
 3078
 3079    for (int i1 = ir0; i1 < ir1; i1++) {
 3080        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
 3081        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 3082
 3083        if (!src1) {
 3084            src0_p += swapped ? nc : 0;
 3085            src1_p += swapped ? 0 : nc;
 3086        }
 3087
 3088        ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3089
 3090#ifndef NDEBUG
 3091        for (int k = 0; k < nc; k++) {
 3092            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3093            const float v = GGML_FP16_TO_FP32(x);
 3094            GGML_UNUSED(v);
 3095            assert(!isnan(v));
 3096            assert(!isinf(v));
 3097        }
 3098#endif
 3099    }
 3100}
 3101
 3102static void ggml_compute_forward_geglu(
 3103        const ggml_compute_params * params,
 3104        ggml_tensor * dst) {
 3105
 3106    const ggml_tensor * src0 = dst->src[0];
 3107
 3108    switch (src0->type) {
 3109        case GGML_TYPE_F32:
 3110            {
 3111                ggml_compute_forward_geglu_f32(params, dst);
 3112            } break;
 3113        case GGML_TYPE_F16:
 3114            {
 3115                ggml_compute_forward_geglu_f16(params, dst);
 3116            } break;
 3117        default:
 3118            {
 3119                GGML_ABORT("fatal error");
 3120            }
 3121    }
 3122}
 3123
 3124// ggml_compute_forward_swiglu
 3125
 3126static void ggml_compute_forward_swiglu_f32(
 3127        const ggml_compute_params * params,
 3128        ggml_tensor * dst) {
 3129
 3130    const ggml_tensor * src0 = dst->src[0];
 3131    const ggml_tensor * src1 = dst->src[1];
 3132    char * src0_d = (char *) src0->data;
 3133    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3134    const size_t src0_o = src0->nb[1];
 3135    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3136
 3137    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3138    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3139
 3140    if (src1) {
 3141        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3142        GGML_ASSERT(src0->type == src1->type);
 3143    }
 3144
 3145    const int ith = params->ith;
 3146    const int nth = params->nth;
 3147
 3148    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3149    const int nr = ggml_nrows(src0);
 3150
 3151    GGML_ASSERT(dst->ne[0] == nc);
 3152    GGML_ASSERT(ggml_nrows(dst) == nr);
 3153
 3154    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3155
 3156    // rows per thread
 3157    const int dr = (nr + nth - 1)/nth;
 3158
 3159    // row range for this thread
 3160    const int ir0 = dr*ith;
 3161    const int ir1 = MIN(ir0 + dr, nr);
 3162
 3163    for (int i1 = ir0; i1 < ir1; i1++) {
 3164        float * src0_p = (float *) (src0_d + i1*src0_o);
 3165        float * src1_p = (float *) (src1_d + i1*src1_o);
 3166
 3167        if (!src1) {
 3168            src0_p += swapped ? nc : 0;
 3169            src1_p += swapped ? 0 : nc;
 3170        }
 3171
 3172        ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3173
 3174#ifndef NDEBUG
 3175        for (int k = 0; k < nc; k++) {
 3176            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3177            GGML_UNUSED(x);
 3178            assert(!isnan(x));
 3179            assert(!isinf(x));
 3180        }
 3181#endif
 3182    }
 3183}
 3184
 3185static void ggml_compute_forward_swiglu_f16(
 3186    const ggml_compute_params * params,
 3187    ggml_tensor * dst) {
 3188
 3189    const ggml_tensor * src0 = dst->src[0];
 3190    const ggml_tensor * src1 = dst->src[1];
 3191    char * src0_d = (char *) src0->data;
 3192    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3193    const size_t src0_o = src0->nb[1];
 3194    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3195
 3196    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3197    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3198
 3199    if (src1) {
 3200        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3201        GGML_ASSERT(src0->type == src1->type);
 3202    }
 3203
 3204    const int ith = params->ith;
 3205    const int nth = params->nth;
 3206
 3207    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3208    const int nr = ggml_nrows(src0);
 3209
 3210    GGML_ASSERT(dst->ne[0] == nc);
 3211    GGML_ASSERT(ggml_nrows(dst) == nr);
 3212
 3213    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3214
 3215    // rows per thread
 3216    const int dr = (nr + nth - 1)/nth;
 3217
 3218    // row range for this thread
 3219    const int ir0 = dr*ith;
 3220    const int ir1 = MIN(ir0 + dr, nr);
 3221
 3222    for (int i1 = ir0; i1 < ir1; i1++) {
 3223        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
 3224        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 3225
 3226        if (!src1) {
 3227            src0_p += swapped ? nc : 0;
 3228            src1_p += swapped ? 0 : nc;
 3229        }
 3230
 3231        ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3232
 3233#ifndef NDEBUG
 3234        for (int k = 0; k < nc; k++) {
 3235            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3236            const float v = GGML_FP16_TO_FP32(x);
 3237            GGML_UNUSED(v);
 3238            assert(!isnan(v));
 3239            assert(!isinf(v));
 3240        }
 3241#endif
 3242    }
 3243}
 3244
 3245static void ggml_compute_forward_swiglu(
 3246        const ggml_compute_params * params,
 3247        ggml_tensor * dst) {
 3248
 3249    const ggml_tensor * src0 = dst->src[0];
 3250
 3251    switch (src0->type) {
 3252        case GGML_TYPE_F32:
 3253            {
 3254                ggml_compute_forward_swiglu_f32(params, dst);
 3255            } break;
 3256        case GGML_TYPE_F16:
 3257            {
 3258                ggml_compute_forward_swiglu_f16(params, dst);
 3259            } break;
 3260        default:
 3261            {
 3262                GGML_ABORT("fatal error");
 3263            }
 3264    }
 3265}
 3266
 3267// ggml_compute_forward_swiglu_oai
 3268
 3269static void ggml_compute_forward_swiglu_oai_f32(
 3270        const ggml_compute_params * params,
 3271        ggml_tensor * dst) {
 3272
 3273    const ggml_tensor * src0 = dst->src[0];
 3274    const ggml_tensor * src1 = dst->src[1];
 3275    char * src0_d = (char *) src0->data;
 3276    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3277    const size_t src0_o = src0->nb[1];
 3278    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3279
 3280    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3281    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3282
 3283    if (src1) {
 3284        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3285        GGML_ASSERT(src0->type == src1->type);
 3286    }
 3287
 3288    const int ith = params->ith;
 3289    const int nth = params->nth;
 3290
 3291    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3292    const int nr = ggml_nrows(src0);
 3293
 3294    GGML_ASSERT(dst->ne[0] == nc);
 3295    GGML_ASSERT(ggml_nrows(dst) == nr);
 3296
 3297    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3298    const float alpha = ggml_get_op_params_f32(dst, 2);
 3299    const float limit = ggml_get_op_params_f32(dst, 3);
 3300
 3301    // rows per thread
 3302    const int dr = (nr + nth - 1)/nth;
 3303
 3304    // row range for this thread
 3305    const int ir0 = dr*ith;
 3306    const int ir1 = MIN(ir0 + dr, nr);
 3307
 3308    for (int i1 = ir0; i1 < ir1; i1++) {
 3309        float * src0_p = (float *) (src0_d + i1*src0_o);
 3310        float * src1_p = (float *) (src1_d + i1*src1_o);
 3311        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
 3312
 3313        if (!src1) {
 3314            src0_p += swapped ? nc : 0;
 3315            src1_p += swapped ? 0 : nc;
 3316        }
 3317
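        // clamp the inputs, gate with an alpha-scaled sigmoid and shift the linear branch by one:
        //   dst[k] = min(x, limit) * sigmoid(alpha*min(x, limit)) * (clamp(y, -limit, limit) + 1)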
 3318        for (int k = 0; k < nc; k++) {
 3319            const float x = std::min(src0_p[k], limit);
 3320            const float y = std::clamp(src1_p[k], -limit, limit);
 3321            const float out_glu = x / (1.f + expf(alpha * (-x)));
 3322            dst_p[k] = out_glu * (y + 1.f);
 3323        }
 3324
 3325#ifndef NDEBUG
 3326        for (int k = 0; k < nc; k++) {
 3327            const float x = dst_p[k];
 3328            GGML_UNUSED(x);
 3329            assert(!isnan(x));
 3330            assert(!isinf(x));
 3331        }
 3332#endif
 3333    }
 3334}
 3335
 3336static void ggml_compute_forward_swiglu_oai(
 3337        const ggml_compute_params * params,
 3338        ggml_tensor * dst) {
 3339
 3340    const ggml_tensor * src0 = dst->src[0];
 3341
 3342    switch (src0->type) {
 3343        case GGML_TYPE_F32:
 3344            {
 3345                ggml_compute_forward_swiglu_oai_f32(params, dst);
 3346            } break;
 3347        default:
 3348            {
 3349                GGML_ABORT("fatal error");
 3350            }
 3351    }
 3352}
 3353
 3354// ggml_compute_forward_geglu_erf
 3355
 3356static void ggml_compute_forward_geglu_erf_f32(
 3357        const ggml_compute_params * params,
 3358        ggml_tensor * dst) {
 3359
 3360    const ggml_tensor * src0 = dst->src[0];
 3361    const ggml_tensor * src1 = dst->src[1];
 3362    char * src0_d = (char *) src0->data;
 3363    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3364    const size_t src0_o = src0->nb[1];
 3365    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3366
 3367    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3368    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3369
 3370    if (src1) {
 3371        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3372        GGML_ASSERT(src0->type == src1->type);
 3373    }
 3374
 3375    const int ith = params->ith;
 3376    const int nth = params->nth;
 3377
 3378    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3379    const int nr = ggml_nrows(src0);
 3380
 3381    GGML_ASSERT(dst->ne[0] == nc);
 3382    GGML_ASSERT(ggml_nrows(dst) == nr);
 3383
 3384    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3385
 3386    // rows per thread
 3387    const int dr = (nr + nth - 1)/nth;
 3388
 3389    // row range for this thread
 3390    const int ir0 = dr*ith;
 3391    const int ir1 = MIN(ir0 + dr, nr);
 3392
 3393    for (int i1 = ir0; i1 < ir1; i1++) {
 3394        float * src0_p = (float *) (src0_d + i1*src0_o);
 3395        float * src1_p = (float *) (src1_d + i1*src1_o);
 3396
 3397        if (!src1) {
 3398            src0_p += swapped ? nc : 0;
 3399            src1_p += swapped ? 0 : nc;
 3400        }
 3401
 3402        ggml_vec_geglu_erf_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3403
 3404#ifndef NDEBUG
 3405        for (int k = 0; k < nc; k++) {
 3406            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3407            GGML_UNUSED(x);
 3408            assert(!isnan(x));
 3409            assert(!isinf(x));
 3410        }
 3411#endif
 3412    }
 3413}
 3414
 3415static void ggml_compute_forward_geglu_erf_f16(
 3416    const ggml_compute_params * params,
 3417    ggml_tensor * dst) {
 3418
 3419    const ggml_tensor * src0 = dst->src[0];
 3420    const ggml_tensor * src1 = dst->src[1];
 3421    char * src0_d = (char *) src0->data;
 3422    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3423    const size_t src0_o = src0->nb[1];
 3424    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3425
 3426    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3427    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3428
 3429    if (src1) {
 3430        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3431        GGML_ASSERT(src0->type == src1->type);
 3432    }
 3433
 3434    const int ith = params->ith;
 3435    const int nth = params->nth;
 3436
 3437    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3438    const int nr = ggml_nrows(src0);
 3439
 3440    GGML_ASSERT(dst->ne[0] == nc);
 3441    GGML_ASSERT(ggml_nrows(dst) == nr);
 3442
 3443    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3444
 3445    // rows per thread
 3446    const int dr = (nr + nth - 1)/nth;
 3447
 3448    // row range for this thread
 3449    const int ir0 = dr*ith;
 3450    const int ir1 = MIN(ir0 + dr, nr);
 3451
 3452    for (int i1 = ir0; i1 < ir1; i1++) {
 3453        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
 3454        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 3455
 3456        if (!src1) {
 3457            src0_p += swapped ? nc : 0;
 3458            src1_p += swapped ? 0 : nc;
 3459        }
 3460
 3461        ggml_vec_geglu_erf_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3462
 3463#ifndef NDEBUG
 3464        for (int k = 0; k < nc; k++) {
 3465            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3466            const float v = GGML_FP16_TO_FP32(x);
 3467            GGML_UNUSED(v);
 3468            assert(!isnan(v));
 3469            assert(!isinf(v));
 3470        }
 3471#endif
 3472    }
 3473}
 3474
 3475static void ggml_compute_forward_geglu_erf(
 3476        const ggml_compute_params * params,
 3477        ggml_tensor * dst) {
 3478
 3479    const ggml_tensor * src0 = dst->src[0];
 3480
 3481    switch (src0->type) {
 3482        case GGML_TYPE_F32:
 3483            {
 3484                ggml_compute_forward_geglu_erf_f32(params, dst);
 3485            } break;
 3486        case GGML_TYPE_F16:
 3487            {
 3488                ggml_compute_forward_geglu_erf_f16(params, dst);
 3489            } break;
 3490        default:
 3491            {
 3492                GGML_ABORT("fatal error");
 3493            }
 3494    }
 3495}
 3496
 3497// ggml_compute_forward_geglu_quick
 3498
 3499static void ggml_compute_forward_geglu_quick_f32(
 3500        const ggml_compute_params * params,
 3501        ggml_tensor * dst) {
 3502
 3503    const ggml_tensor * src0 = dst->src[0];
 3504    const ggml_tensor * src1 = dst->src[1];
 3505    char * src0_d = (char *) src0->data;
 3506    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3507    const size_t src0_o = src0->nb[1];
 3508    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3509
 3510    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3511    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3512
 3513    if (src1) {
 3514        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3515        GGML_ASSERT(src0->type == src1->type);
 3516    }
 3517
 3518    const int ith = params->ith;
 3519    const int nth = params->nth;
 3520
 3521    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3522    const int nr = ggml_nrows(src0);
 3523
 3524    GGML_ASSERT(dst->ne[0] == nc);
 3525    GGML_ASSERT(ggml_nrows(dst) == nr);
 3526
 3527    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3528
 3529    // rows per thread
 3530    const int dr = (nr + nth - 1)/nth;
 3531
 3532    // row range for this thread
 3533    const int ir0 = dr*ith;
 3534    const int ir1 = MIN(ir0 + dr, nr);
 3535
 3536    for (int i1 = ir0; i1 < ir1; i1++) {
 3537        float * src0_p = (float *) (src0_d + i1*src0_o);
 3538        float * src1_p = (float *) (src1_d + i1*src1_o);
 3539
 3540        if (!src1) {
 3541            src0_p += swapped ? nc : 0;
 3542            src1_p += swapped ? 0 : nc;
 3543        }
 3544
 3545        ggml_vec_geglu_quick_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3546
 3547#ifndef NDEBUG
 3548        for (int k = 0; k < nc; k++) {
 3549            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3550            GGML_UNUSED(x);
 3551            assert(!isnan(x));
 3552            assert(!isinf(x));
 3553        }
 3554#endif
 3555    }
 3556}
 3557
 3558static void ggml_compute_forward_geglu_quick_f16(
 3559    const ggml_compute_params * params,
 3560    ggml_tensor * dst) {
 3561
 3562    const ggml_tensor * src0 = dst->src[0];
 3563    const ggml_tensor * src1 = dst->src[1];
 3564    char * src0_d = (char *) src0->data;
 3565    char * src1_d = (char *) (src1 ? src1->data : src0->data);
 3566    const size_t src0_o = src0->nb[1];
 3567    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
 3568
 3569    GGML_ASSERT(ggml_is_contiguous_1(src0));
 3570    GGML_ASSERT(ggml_is_contiguous_1(dst));
 3571
 3572    if (src1) {
 3573        GGML_ASSERT(ggml_is_contiguous_1(src1));
 3574        GGML_ASSERT(src0->type == src1->type);
 3575    }
 3576
 3577    const int ith = params->ith;
 3578    const int nth = params->nth;
 3579
 3580    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
 3581    const int nr = ggml_nrows(src0);
 3582
 3583    GGML_ASSERT(dst->ne[0] == nc);
 3584    GGML_ASSERT(ggml_nrows(dst) == nr);
 3585
 3586    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
 3587
 3588    // rows per thread
 3589    const int dr = (nr + nth - 1)/nth;
 3590
 3591    // row range for this thread
 3592    const int ir0 = dr*ith;
 3593    const int ir1 = MIN(ir0 + dr, nr);
 3594
 3595    for (int i1 = ir0; i1 < ir1; i1++) {
 3596        ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
 3597        ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
 3598
 3599        if (!src1) {
 3600            src0_p += swapped ? nc : 0;
 3601            src1_p += swapped ? 0 : nc;
 3602        }
 3603
 3604        ggml_vec_geglu_quick_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
 3605
 3606#ifndef NDEBUG
 3607        for (int k = 0; k < nc; k++) {
 3608            const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
 3609            const float v = GGML_FP16_TO_FP32(x);
 3610            GGML_UNUSED(v);
 3611            assert(!isnan(v));
 3612            assert(!isinf(v));
 3613        }
 3614#endif
 3615    }
 3616}
 3617
 3618static void ggml_compute_forward_geglu_quick(
 3619        const ggml_compute_params * params,
 3620        ggml_tensor * dst) {
 3621
 3622    const ggml_tensor * src0 = dst->src[0];
 3623
 3624    switch (src0->type) {
 3625        case GGML_TYPE_F32:
 3626            {
 3627                ggml_compute_forward_geglu_quick_f32(params, dst);
 3628            } break;
 3629        case GGML_TYPE_F16:
 3630            {
 3631                ggml_compute_forward_geglu_quick_f16(params, dst);
 3632            } break;
 3633        default:
 3634            {
 3635                GGML_ABORT("fatal error");
 3636            }
 3637    }
 3638}
 3639
 3640// ggml_compute_forward_norm
 3641
 3642static void ggml_compute_forward_norm_f32(
 3643        const ggml_compute_params * params,
 3644        ggml_tensor * dst) {
 3645
 3646    const ggml_tensor * src0 = dst->src[0];
 3647
 3648    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 3649
 3650    GGML_ASSERT(src0->nb[0] == sizeof(float));
 3651
 3652    const int ith = params->ith;
 3653    const int nth = params->nth;
 3654
 3655    GGML_TENSOR_UNARY_OP_LOCALS
 3656
 3657    float eps;
 3658    memcpy(&eps, dst->op_params, sizeof(float));
 3659
 3660    GGML_ASSERT(eps >= 0.0f);
 3661
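    // standard layer norm over each row: y = (x - mean(x)) / sqrt(var(x) + eps)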
 3662    for (int64_t i03 = 0; i03 < ne03; i03++) {
 3663        for (int64_t i02 = 0; i02 < ne02; i02++) {
 3664            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 3665                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 3666
 3667                float sum = 0.0;
 3668                ggml_vec_sum_f32(ne00, &sum, x);
 3669                float mean = sum/ne00;
 3670
 3671                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 3672                float variance = 0;
 3673
 3674#ifdef GGML_USE_ACCELERATE
 3675                mean = -mean;
 3676                vDSP_vsadd(x, 1, &mean, y, 1, ne00);
 3677                vDSP_measqv(y, 1, &variance, ne00);
 3678#else
 3679                variance = ggml_vec_cvar_f32(ne00, y, x, mean);
 3680#endif //GGML_USE_ACCELERATE
 3681
 3682                const float scale = 1.0f/sqrtf(variance + eps);
 3683                ggml_vec_scale_f32(ne00, y, scale);
 3684            }
 3685        }
 3686    }
 3687}
 3688
 3689void ggml_compute_forward_norm(
 3690        const ggml_compute_params * params,
 3691        ggml_tensor * dst) {
 3692
 3693    const ggml_tensor * src0 = dst->src[0];
 3694
 3695    switch (src0->type) {
 3696        case GGML_TYPE_F32:
 3697            {
 3698                ggml_compute_forward_norm_f32(params, dst);
 3699            } break;
 3700        default:
 3701            {
 3702                GGML_ABORT("fatal error");
 3703            }
 3704    }
 3705}
 3706
 3707// ggml_compute_forward_rms_norm
 3708
 3709static void ggml_compute_forward_rms_norm_f32(
 3710        const ggml_compute_params * params,
 3711        ggml_tensor * dst) {
 3712
 3713    const ggml_tensor * src0 = dst->src[0];
 3714
 3715    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 3716
 3717    GGML_ASSERT(src0->nb[0] == sizeof(float));
 3718
 3719    const int ith = params->ith;
 3720    const int nth = params->nth;
 3721
 3722    GGML_TENSOR_UNARY_OP_LOCALS
 3723
 3724    float eps;
 3725    memcpy(&eps, dst->op_params, sizeof(float));
 3726
 3727    GGML_ASSERT(eps >= 0.0f);
 3728
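    // RMS norm over each row: y = x / sqrt(mean(x^2) + eps)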
 3729    // TODO: optimize
 3730    for (int64_t i03 = 0; i03 < ne03; i03++) {
 3731        for (int64_t i02 = 0; i02 < ne02; i02++) {
 3732            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 3733                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 3734
 3735                ggml_float sum = 0.0;
 3736                for (int64_t i00 = 0; i00 < ne00; i00++) {
 3737                    sum += (ggml_float)(x[i00] * x[i00]);
 3738                }
 3739
 3740                const float mean = sum/ne00;
 3741
 3742                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 3743
 3744                memcpy(y, x, ne00 * sizeof(float));
 3745                // for (int i00 = 0; i00 < ne00; i00++) {
 3746                //     y[i00] = x[i00];
 3747                // }
 3748
 3749                const float scale = 1.0f/sqrtf(mean + eps);
 3750
 3751                // if you hit this, likely you got an inf somewhere earlier
 3752                assert(scale > 0.0f);
 3753
 3754                ggml_vec_scale_f32(ne00, y, scale);
 3755            }
 3756        }
 3757    }
 3758}
 3759
 3760void ggml_compute_forward_rms_norm(
 3761        const ggml_compute_params * params,
 3762        ggml_tensor * dst) {
 3763
 3764    const ggml_tensor * src0 = dst->src[0];
 3765
 3766    switch (src0->type) {
 3767        case GGML_TYPE_F32:
 3768            {
 3769                ggml_compute_forward_rms_norm_f32(params, dst);
 3770            } break;
 3771        default:
 3772            {
 3773                GGML_ABORT("fatal error");
 3774            }
 3775    }
 3776}
 3777
 3778static void ggml_compute_forward_rms_norm_back_f32(
 3779        const ggml_compute_params * params,
 3780        ggml_tensor * dst) {
 3781
 3782    const ggml_tensor * src0 = dst->src[0]; // gradients from forward pass output
 3783    const ggml_tensor * src1 = dst->src[1]; // src1 from forward pass
 3784
 3785    GGML_ASSERT(ggml_are_same_shape(src0, dst) && ggml_are_same_shape(src0, src1));
 3786
 3787    GGML_ASSERT(src0->nb[0] == sizeof(float));
 3788    GGML_ASSERT(src1->nb[0] == sizeof(float));
 3789
 3790    const int ith = params->ith;
 3791    const int nth = params->nth;
 3792
 3793    GGML_TENSOR_BINARY_OP_LOCALS
 3794
 3795    float eps;
 3796    memcpy(&eps, dst->op_params, sizeof(float));
 3797
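    // computes dx = (dz - x*sum(x*dz)/(sum(x^2) + N*eps)) / sqrt(mean(x^2) + eps)
    // (see the step-by-step derivation in the comment block below)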
 3798    // TODO: optimize
 3799    for (int64_t i03 = 0; i03 < ne03; i03++) {
 3800        for (int64_t i02 = 0; i02 < ne02; i02++) {
 3801            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 3802                // src1 is same shape as src0 => same indices
 3803                const int64_t i11 = i01;
 3804                const int64_t i12 = i02;
 3805                const int64_t i13 = i03;
 3806
 3807                const float * dz = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 3808                const float * x  = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13);
 3809
 3810                ggml_float sum_xx  = 0.0;
 3811                ggml_float sum_xdz = 0.0;
 3812
 3813                for (int64_t i00 = 0; i00 < ne00; i00++) {
 3814                    sum_xx  += (ggml_float)(x[i00] * x[i00]);
 3815                    sum_xdz += (ggml_float)(x[i00] * dz[i00]);
 3816                }
 3817
 3818                //const float mean     = (float)(sum_xx)/ne00;
 3819                const float mean_eps = (float)(sum_xx)/ne00 + eps;
 3820                const float sum_eps  = (float)(sum_xx) + eps*ne00;
 3821                //const float mean_xdz = (float)(sum_xdz)/ne00;
 3822                // we could cache rms from forward pass to improve performance.
 3823                // to do this implement ggml_rms and compose ggml_rms_norm using ggml_rms.
 3824                //const float rms      = sqrtf(mean_eps);
 3825                const float rrms     = 1.0f / sqrtf(mean_eps);
 3826                //const float scale    = -rrms/(ne00 * mean_eps); // -1/(n*rms**3)
 3827
 3828                {
 3829                    // z = rms_norm(x)
 3830                    //
 3831                    // rms_norm(src1) =
 3832                    //     scale(
 3833                    //         src1,
 3834                    //         div(
 3835                    //             1,
 3836                    //             sqrt(
 3837                    //                 add(
 3838                    //                     scale(
 3839                    //                         sum(
 3840                    //                             sqr(
 3841                    //                                 src1)),
 3842                    //                         (1.0/N)),
 3843                    //                     eps))));
 3844
 3845                    // postorder:
 3846                    // ## op    args         grad
 3847                    // 00 param src1         grad[#00]
 3848                    // 01 const 1
 3849                    // 02 sqr   (#00)        grad[#02]
 3850                    // 03 sum   (#02)        grad[#03]
 3851                    // 04 const 1/N
 3852                    // 05 scale (#03, #04)   grad[#05]
 3853                    // 06 const eps
 3854                    // 07 add   (#05, #06)   grad[#07]
 3855                    // 08 sqrt  (#07)        grad[#08]
 3856                    // 09 div   (#01,#08)    grad[#09]
 3857                    // 10 scale (#00,#09)    grad[#10]
 3858                    //
 3859                    // backward pass, given grad[#10]
 3860                    // #10: scale
 3861                    // grad[#00] += scale(grad[#10],#09)
 3862                    // grad[#09] += sum(mul(grad[#10],#00))
 3863                    // #09: div
 3864                    // grad[#08] += neg(mul(grad[#09], div(#09,#08)))
 3865                    // #08: sqrt
 3866                    // grad[#07] += mul(grad[#08], div(0.5, #08))
 3867                    // #07: add
 3868                    // grad[#05] += grad[#07]
 3869                    // #05: scale
 3870                    // grad[#03] += scale(grad[#05],#04)
 3871                    // #03: sum
 3872                    // grad[#02] += repeat(grad[#03], #02)
 3873                    // #02: sqr
 3874                    // grad[#00] += scale(mul(#00, grad[#02]), 2.0)
 3875                    //
 3876                    // substitute and simplify:
 3877                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
 3878                    // grad[#02] = repeat(grad[#03], #02)
 3879                    // grad[#02] = repeat(scale(grad[#05],#04), #02)
 3880                    // grad[#02] = repeat(scale(grad[#07],#04), #02)
 3881                    // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02)
 3882                    // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02)
 3883                    // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02)
 3884                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02)
 3885                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02)
 3886                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02)
 3887                    // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)
 3888                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0)
 3889                    // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0)
 3890                    // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0)
 3891                    // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N)))
 3892                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N))
 3894                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,mean_eps*rms) * (-1/N))
 3895                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*mean_eps))
 3896                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*(sum_xx/N+eps)))
 3897                    // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(-1,rms*N*sum_xx+rms*N*eps))
 3898                    // grad[#00] = scale(dz, rrms) + scale(x, sum(mul(dz,x)) * div(-1,rms*N*mean_eps))
 3899                    // grad[#00] = scale(dz, rrms) + scale(x, sum_xdz * div(-1,rms*N*mean_eps))
 3900                    // a = b*c + d*e
 3901                    // a = b*c*f/f + d*e*f/f
 3902                    // a = (b*c*f + d*e*f)*(1/f)
 3903                    // a = (b*c*(1/c) + d*e*(1/c))*(1/(1/c))
 3904                    // a = (b + d*e/c)*c
 3905                    // b = dz, c = rrms, d = x, e = sum_xdz * div(-1,rms*N*mean_eps)
 3906                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)/rrms)*rrms
 3907                    // a = (dz + x*sum_xdz * div(-1,rms*N*mean_eps)*rms)*rrms
 3908                    // a = (dz + x*sum_xdz * div(-rms,rms*N*mean_eps))*rrms
 3909                    // a = (dz + x*sum_xdz * div(-1,N*mean_eps))*rrms
 3910                    // a = (dz + x*div(-sum_xdz,N*mean_eps))*rrms
 3911                    // a = (dz + x*div(-mean_xdz,mean_eps))*rrms
 3912                    // grad[#00] = scale(dz + scale(x, div(-mean_xdz,mean_eps)),rrms)
 3913                    // grad[#00] = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
 3914                    // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
 3915                }
 3916                // dx = scale(dz + scale(x, -mean_xdz/mean_eps),rrms)
 3917                // post-order:
 3918                // dx := x
 3919                // dx := scale(dx,-mean_xdz/mean_eps)
 3920                // dx := add(dx, dz)
 3921                // dx := scale(dx, rrms)
 3922                float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 3923
 3924                // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps)
 3925                ggml_vec_cpy_f32  (ne00, dx, x);
 3926                // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps);
 3927                ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps);
 3928                ggml_vec_acc_f32  (ne00, dx, dz);
 3929                ggml_vec_scale_f32(ne00, dx, rrms);
 3930            }
 3931        }
 3932    }
 3933}
 3934
 3935void ggml_compute_forward_rms_norm_back(
 3936        const ggml_compute_params * params,
 3937        ggml_tensor * dst) {
 3938
 3939    const ggml_tensor * src0 = dst->src[0];
 3940
 3941    switch (src0->type) {
 3942        case GGML_TYPE_F32:
 3943            {
 3944                ggml_compute_forward_rms_norm_back_f32(params, dst);
 3945            } break;
 3946        default:
 3947            {
 3948                GGML_ABORT("fatal error");
 3949            }
 3950    }
 3951}
 3952
 3953// ggml_compute_forward_group_norm
 3954
 3955static void ggml_compute_forward_group_norm_f32(
 3956    const ggml_compute_params * params,
 3957    ggml_tensor * dst) {
 3958
 3959    const ggml_tensor * src0 = dst->src[0];
 3960
 3961    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 3962
 3963    GGML_ASSERT(src0->nb[0] == sizeof(float));
 3964
 3965    const int ith = params->ith;
 3966    const int nth = params->nth;
 3967
 3968    GGML_TENSOR_UNARY_OP_LOCALS
 3969
 3970    // TODO: optimize
 3971
 3972    float eps;
 3973    memcpy(&eps, dst->op_params + 1, sizeof(float));
 3974
 3975    int n_channels = src0->ne[2];
 3976    int n_groups = dst->op_params[0];
 3977    int n_channels_per_group = (n_channels + n_groups - 1) / n_groups;
 3978    for (int i = ith; i < n_groups; i += nth) {
 3979        int start = i * n_channels_per_group;
 3980        int end = start + n_channels_per_group;
 3981        if (end > n_channels) {
 3982            end = n_channels;
 3983        }
 3984        int step = end - start;
 3985
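        // normalize all elements of channels [start, end) of this image together:
        // subtract the group mean, then scale by 1/sqrt(group variance + eps)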
 3986        for (int64_t i03 = 0; i03 < ne03; i03++) {
 3987            ggml_float sum = 0.0;
 3988            for (int64_t i02 = start; i02 < end; i02++) {
 3989                for (int64_t i01 = 0; i01 < ne01; i01++) {
 3990                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
 3991
 3992                    ggml_float sumr = 0.0;
 3993                    for (int64_t i00 = 0; i00 < ne00; i00++) {
 3994                        sumr += (ggml_float)x[i00];
 3995                    }
 3996                    sum += sumr;
 3997                }
 3998            }
 3999            const float mean = sum / (ne00 * ne01 * step);
 4000
 4001            ggml_float sum2 = 0.0;
 4002            for (int64_t i02 = start; i02 < end; i02++) {
 4003                for (int64_t i01 = 0; i01 < ne01; i01++) {
 4004                    const float * x = (float *)((char *) src0->data + i01 * nb01 + i02 * nb02 + i03 * nb03);
 4005
 4006                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
 4007
 4008                    ggml_float sumr = 0.0;
 4009                    for (int64_t i00 = 0; i00 < ne00; i00++) {
 4010                        float v = x[i00] - mean;
 4011                        y[i00] = v;
 4012                        sumr += (ggml_float)(v * v);
 4013                    }
 4014                    sum2 += sumr;
 4015                }
 4016            }
 4017            const float variance = sum2 / (ne00 * ne01 * step);
 4018            const float scale = 1.0f / sqrtf(variance + eps);
 4019
 4020            for (int64_t i02 = start; i02 < end; i02++) {
 4021                for (int64_t i01 = 0; i01 < ne01; i01++) {
 4022                    float * y = (float *)((char *) dst->data + i01 * nb1 + i02 * nb2 + i03 * nb3);
 4023                    ggml_vec_scale_f32(ne00, y, scale);
 4024                }
 4025            }
 4026        }
 4027    }
 4028}
 4029
 4030void ggml_compute_forward_group_norm(
 4031    const ggml_compute_params * params,
 4032    ggml_tensor * dst) {
 4033
 4034    const ggml_tensor * src0 = dst->src[0];
 4035
 4036    switch (src0->type) {
 4037        case GGML_TYPE_F32:
 4038            {
 4039                ggml_compute_forward_group_norm_f32(params, dst);
 4040            } break;
 4041        default:
 4042            {
 4043                GGML_ABORT("fatal error");
 4044            }
 4045    }
 4046}
 4047
 4048// ggml_compute_forward_l2_norm
 4049
 4050static void ggml_compute_forward_l2_norm_f32(
 4051    const ggml_compute_params * params,
 4052    ggml_tensor * dst) {
 4053
 4054    const ggml_tensor * src0 = dst->src[0];
 4055
 4056    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 4057
 4058    GGML_ASSERT(src0->nb[0] == sizeof(float));
 4059
 4060    const int ith = params->ith;
 4061    const int nth = params->nth;
 4062
 4063    GGML_TENSOR_UNARY_OP_LOCALS
 4064
 4065    float eps;
 4066    memcpy(&eps, dst->op_params, sizeof(float));
 4067
 4068    GGML_ASSERT(eps >= 0.0f);
 4069
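    // scale each row to unit L2 norm: y = x / max(||x||_2, eps)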
 4070    // TODO: optimize
 4071    for (int64_t i03 = 0; i03 < ne03; i03++) {
 4072        for (int64_t i02 = 0; i02 < ne02; i02++) {
 4073            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 4074                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 4075
 4076                ggml_float sum = 0.0;
 4077                for (int64_t i00 = 0; i00 < ne00; i00++) {
 4078                    sum += (ggml_float)(x[i00] * x[i00]);
 4079                }
 4080
 4081                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 4082
 4083                memcpy(y, x, ne00 * sizeof(float));
 4084
 4085                const float scale = 1.0f/fmaxf(sqrtf(sum), eps);
 4086
 4087                ggml_vec_scale_f32(ne00, y, scale);
 4088            }
 4089        }
 4090    }
 4091}
 4092
 4093void ggml_compute_forward_l2_norm(
 4094    const ggml_compute_params * params,
 4095    ggml_tensor * dst) {
 4096
 4097    const ggml_tensor * src0 = dst->src[0];
 4098
 4099    switch (src0->type) {
 4100        case GGML_TYPE_F32:
 4101            {
 4102                ggml_compute_forward_l2_norm_f32(params, dst);
 4103            } break;
 4104        default:
 4105            {
 4106                GGML_ABORT("fatal error");
 4107            }
 4108    }
 4109}
 4110
 4111// ggml_compute_forward_out_prod
 4112
 4113static void ggml_compute_forward_out_prod_f32(
 4114        const ggml_compute_params * params,
 4115              ggml_tensor * dst) {
 4116
 4117    const ggml_tensor * src0 = dst->src[0];
 4118    const ggml_tensor * src1 = dst->src[1];
 4119
 4120    GGML_TENSOR_BINARY_OP_LOCALS
 4121
 4122    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 4123    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 4124    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 4125
 4126    const int ith = params->ith;
 4127    const int nth = params->nth;
 4128
 4129    GGML_ASSERT(ne0 == ne00);
 4130    GGML_ASSERT(ne1 == ne10);
 4131    GGML_ASSERT(ne2 == ne12);
 4132    GGML_ASSERT(ne3 == ne13);
 4133
 4134    GGML_ASSERT(ne2 % ne02 == 0);
 4135    GGML_ASSERT(ne3 % ne03 == 0);
 4136
 4137    // we don't support permuted src0 or src1
 4138    GGML_ASSERT(nb00 == sizeof(float));
 4139
 4140    // dst cannot be transposed or permuted
 4141    GGML_ASSERT(nb0 == sizeof(float));
 4142    // GGML_ASSERT(nb0 <= nb1);
 4143    // GGML_ASSERT(nb1 <= nb2);
 4144    // GGML_ASSERT(nb2 <= nb3);
 4145
 4146    // nb01 >= nb00 - src0 is not transposed
 4147    //   compute by src0 rows
 4148
 4149    if (ith == 0) {
 4150        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
 4151    }
 4152    ggml_barrier(params->threadpool);
 4153
 4154    // dst[:,:,:,:] = 0
 4155    // for i2,i3:
 4156    //   for i1:
 4157    //     for i01:
 4158    //       for i0:
 4159    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
 4160
 4161    // parallelize by last three dimensions
 4162
 4163    // total rows in dst
 4164    const int64_t nr = ne1*ne2*ne3;
 4165
 4166    // rows per thread
 4167    const int64_t dr = (nr + nth - 1)/nth;
 4168
 4169    // row range for this thread
 4170    const int64_t ir0 = dr*ith;
 4171    const int64_t ir1 = MIN(ir0 + dr, nr);
 4172
 4173    // block-tiling attempt
 4174    const int64_t blck_0 = MAX(GGML_VEC_MAD_UNROLL, 32);
 4175    const int64_t blck_1 = 16;
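    // tile the dst rows in blocks of blck_1 and the shared i01 dimension in blocks of blck_0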
 4176
 4177    // dps == dst per src0, used for group query attention
 4178    const int64_t dps2 = ne2 / ne02;
 4179    const int64_t dps3 = ne3 / ne03;
 4180
 4181    for (int64_t bir = ir0; bir < ir1; bir += blck_1) {
 4182        const int64_t bir1 = MIN(bir + blck_1, ir1);
 4183        for (int64_t bi01 = 0; bi01 < ne01; bi01 += blck_0) {
 4184            const int64_t bne01 = MIN(bi01 + blck_0, ne01);
 4185            for (int64_t ir = bir; ir < bir1; ++ir) {
 4186                // dst indices
 4187                const int64_t i3 = ir/(ne2*ne1);
 4188                const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
 4189                const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 4190
 4191                const int64_t i02 = i2 / dps2;
 4192                const int64_t i03 = i3 / dps3;
 4193
 4194                //const int64_t i10 = i1;
 4195                const int64_t i12 = i2;
 4196                const int64_t i13 = i3;
 4197
 4198#if GGML_VEC_MAD_UNROLL > 2
 4199                const int64_t bne01_unroll = bne01 - (bne01 % GGML_VEC_MAD_UNROLL);
 4200                for (int64_t i01 = bi01; i01 < bne01_unroll; i01 += GGML_VEC_MAD_UNROLL) {
 4201                    const int64_t i11 = i01;
 4202
 4203                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
 4204                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
 4205                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 4206
 4207                    ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
 4208                }
 4209                for (int64_t i01 = bne01_unroll; i01 < bne01; ++i01) {
 4210                    const int64_t i11 = i01;
 4211
 4212                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
 4213                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
 4214                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1   + i2*nb2   + i3*nb3));
 4215
 4216                    ggml_vec_mad_f32(ne0, d, s0, *s1);
 4217                }
 4218#else
 4219                for (int64_t i01 = bi01; i01 < bne01; ++i01) {
 4220                    const int64_t i11 = i01;
 4221
 4222                    float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
 4223                    float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
 4224                    float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 4225
 4226                    ggml_vec_mad_f32(ne0, d, s0, *s1);
 4227                }
 4228#endif
 4229            }
 4230        }
 4231    }
 4232}
 4233
 4234static void ggml_compute_forward_out_prod_q_f32(
 4235        const ggml_compute_params * params,
 4236              ggml_tensor * dst) {
 4237
 4238    const ggml_tensor * src0 = dst->src[0];
 4239    const ggml_tensor * src1 = dst->src[1];
 4240
 4241    GGML_TENSOR_BINARY_OP_LOCALS;
 4242
 4243    const int ith = params->ith;
 4244    const int nth = params->nth;
 4245
 4246    const ggml_type type = src0->type;
 4247    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
 4248
 4249    GGML_ASSERT(ne02 == ne12);
 4250    GGML_ASSERT(ne03 == ne13);
 4251    GGML_ASSERT(ne2  == ne12);
 4252    GGML_ASSERT(ne3  == ne13);
 4253
 4254    // we don't support permuted src0 dim0
 4255    GGML_ASSERT(nb00 == ggml_type_size(type));
 4256
 4257    // dst dim0 cannot be transposed or permuted
 4258    GGML_ASSERT(nb0 == sizeof(float));
 4259    // GGML_ASSERT(nb0 <= nb1);
 4260    // GGML_ASSERT(nb1 <= nb2);
 4261    // GGML_ASSERT(nb2 <= nb3);
 4262
 4263    GGML_ASSERT(ne0 == ne00);
 4264    GGML_ASSERT(ne1 == ne10);
 4265    GGML_ASSERT(ne2 == ne02);
 4266    GGML_ASSERT(ne3 == ne03);
 4267
 4268    // nb01 >= nb00 - src0 is not transposed
 4269    //   compute by src0 rows
 4270
 4271    if (ith == 0) {
 4272        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
 4273    }
 4274    ggml_barrier(params->threadpool);
 4275
 4276    // parallelize by last three dimensions
 4277
 4278    // total rows in dst
 4279    const int64_t nr = ne1*ne2*ne3;
 4280
 4281    // rows per thread
 4282    const int64_t dr = (nr + nth - 1)/nth;
 4283
 4284    // row range for this thread
 4285    const int64_t ir0 = dr*ith;
 4286    const int64_t ir1 = MIN(ir0 + dr, nr);
 4287
 4288    // dst[:,:,:,:] = 0
 4289    // for i2,i3:
 4290    //   for i1:
 4291    //     for i01:
 4292    //       for i0:
 4293    //         dst[i0,i1,i2,i3] += src0[i0,i01,i2,i3] * src1[i1,i01,i2,i3]
 4294
 4295    float * wdata = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32) * ith;
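    // per-thread scratch row: each quantized src0 row is dequantized into wdata
    // before being accumulated into dst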
 4296
 4297    for (int64_t ir = ir0; ir < ir1; ++ir) {
 4298        // dst indices
 4299        const int64_t i3 = ir/(ne2*ne1);
 4300        const int64_t i2 = (ir - i3*ne2*ne1)/ne1;
 4301        const int64_t i1 = (ir - i3*ne2*ne1 - i2*ne1);
 4302
 4303        const int64_t i02 = i2;
 4304        const int64_t i03 = i3;
 4305
 4306        //const int64_t i10 = i1;
 4307        const int64_t i12 = i2;
 4308        const int64_t i13 = i3;
 4309
 4310        for (int64_t i01 = 0; i01 < ne01; ++i01) {
 4311            const int64_t i11 = i01;
 4312
 4313            float * s0 = (float *) ((char *) src0->data + (          i01*nb01 + i02*nb02 + i03*nb03));
 4314            float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
 4315            float * d  = (float *) ((char *)  dst->data + (          i1*nb1 + i2*nb2 + i3*nb3));
 4316
 4317            dequantize_row_q(s0, wdata, ne0);
 4318            ggml_vec_mad_f32(ne0, d, wdata, *s1);
 4319        }
 4320    }
 4321}
 4322
 4323void ggml_compute_forward_out_prod(
 4324        const ggml_compute_params * params,
 4325        ggml_tensor * dst) {
 4326
 4327    const ggml_tensor * src0 = dst->src[0];
 4328
 4329    switch (src0->type) {
 4330        case GGML_TYPE_Q4_0:
 4331        case GGML_TYPE_Q4_1:
 4332        case GGML_TYPE_Q5_0:
 4333        case GGML_TYPE_Q5_1:
 4334        case GGML_TYPE_Q8_0:
 4335        case GGML_TYPE_MXFP4:
 4336        case GGML_TYPE_Q2_K:
 4337        case GGML_TYPE_Q3_K:
 4338        case GGML_TYPE_Q4_K:
 4339        case GGML_TYPE_Q5_K:
 4340        case GGML_TYPE_Q6_K:
 4341        case GGML_TYPE_TQ1_0:
 4342        case GGML_TYPE_TQ2_0:
 4343        case GGML_TYPE_IQ2_XXS:
 4344        case GGML_TYPE_IQ2_XS:
 4345        case GGML_TYPE_IQ3_XXS:
 4346        case GGML_TYPE_IQ1_S:
 4347        case GGML_TYPE_IQ1_M:
 4348        case GGML_TYPE_IQ4_NL:
 4349        case GGML_TYPE_IQ4_XS:
 4350        case GGML_TYPE_IQ3_S:
 4351        case GGML_TYPE_IQ2_S:
 4352            {
 4353                ggml_compute_forward_out_prod_q_f32(params, dst);
 4354            } break;
 4355        case GGML_TYPE_F16:
 4356            {
 4357                GGML_ABORT("fatal error"); // todo
 4358                // ggml_compute_forward_out_prod_f16_f32(params, dst);
 4359            }
 4360        case GGML_TYPE_F32:
 4361            {
 4362                ggml_compute_forward_out_prod_f32(params, dst);
 4363            } break;
 4364        default:
 4365            {
 4366                GGML_ABORT("fatal error");
 4367            }
 4368    }
 4369}
 4370
 4371// ggml_compute_forward_scale
 4372
 4373static void ggml_compute_forward_scale_f32(
 4374        const ggml_compute_params * params,
 4375        ggml_tensor * dst) {
 4376
 4377    const ggml_tensor * src0 = dst->src[0];
 4378
 4379    GGML_ASSERT(ggml_is_contiguous(src0));
 4380    GGML_ASSERT(ggml_is_contiguous(dst));
 4381    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 4382
 4383    float s; // scale factor
 4384    float b; // bias
 4385
 4386    memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
 4387    memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
 4388
 4389    const int ith = params->ith;
 4390    const int nth = params->nth;
 4391
 4392    const int nc = src0->ne[0];
 4393    const int nr = ggml_nrows(src0);
 4394
 4395    // rows per thread
 4396    const int dr = (nr + nth - 1)/nth;
 4397
 4398    // row range for this thread
 4399    const int ir0 = dr*ith;
 4400    const int ir1 = MIN(ir0 + dr, nr);
 4401
 4402    const size_t nb01 = src0->nb[1];
 4403
 4404    const size_t nb1 = dst->nb[1];
 4405
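    // dst = s*src0 + b; with b == 0 copy (if needed) and scale in place,
    // otherwise use the fused multiply-add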
 4406    if (b == 0.0f) {
 4407        for (int i1 = ir0; i1 < ir1; i1++) {
 4408            if (dst->data != src0->data) {
 4409                // src0 is same shape as dst => same indices
 4410                // TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
 4411                memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
 4412            }
 4413            ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
 4414        }
 4415    } else {
 4416        for (int i1 = ir0; i1 < ir1; i1++) {
 4417            ggml_vec_mad1_f32(nc,
 4418                (float *) ((char *) dst->data  + i1*nb1),
 4419                (float *) ((char *) src0->data + i1*nb01),
 4420                s, b);
 4421        }
 4422    }
 4423}
 4424
 4425void ggml_compute_forward_scale(
 4426        const ggml_compute_params * params,
 4427        ggml_tensor * dst) {
 4428
 4429    const ggml_tensor * src0 = dst->src[0];
 4430
 4431    switch (src0->type) {
 4432        case GGML_TYPE_F32:
 4433            {
 4434                ggml_compute_forward_scale_f32(params, dst);
 4435            } break;
 4436        default:
 4437            {
 4438                GGML_ABORT("fatal error");
 4439            }
 4440    }
 4441}
 4442
 4443// ggml_compute_forward_set
 4444
 4445static void ggml_compute_forward_set_f32(
 4446        const ggml_compute_params * params,
 4447        ggml_tensor * dst) {
 4448
 4449    const ggml_tensor * src0 = dst->src[0];
 4450    const ggml_tensor * src1 = dst->src[1];
 4451
 4452    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 4453    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 4454
 4455    // view src0 and dst with these strides and data offset in bytes during set
 4456    // nb0 is implicitly element_size because src0 and dst are contiguous
 4457    size_t nb1     = ((int32_t *) dst->op_params)[0];
 4458    size_t nb2     = ((int32_t *) dst->op_params)[1];
 4459    size_t nb3     = ((int32_t *) dst->op_params)[2];
 4460    size_t offset  = ((int32_t *) dst->op_params)[3];
 4461    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 4462
 4463    if (!inplace) {
 4464        if (params->ith == 0) {
 4465            // memcpy needs to be synchronized across threads to avoid race conditions.
 4466            // => do it in INIT phase
 4467            memcpy(
 4468                ((char *)  dst->data),
 4469                ((char *) src0->data),
 4470                ggml_nbytes(dst));
 4471        }
 4472        ggml_barrier(params->threadpool);
 4473    }
 4474
 4475    const int ith = params->ith;
 4476    const int nth = params->nth;
 4477
 4478    const int nr = ggml_nrows(src1);
 4479    const int nc = src1->ne[0];
 4480
 4481    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
 4482    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 4483
 4484    // src0 and dst as viewed during set
 4485    const size_t nb0 = ggml_element_size(src0);
 4486
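    // index of the last element written through the view; the assert below
    // checks that the view (strides nb0..nb3 plus offset) stays within dst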
 4487    const int im0 = (ne10 == 0 ? 0 : ne10-1);
 4488    const int im1 = (ne11 == 0 ? 0 : ne11-1);
 4489    const int im2 = (ne12 == 0 ? 0 : ne12-1);
 4490    const int im3 = (ne13 == 0 ? 0 : ne13-1);
 4491
 4492    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
 4493
 4494    GGML_ASSERT(nb10 == sizeof(float));
 4495
 4496    // rows per thread
 4497    const int dr = (nr + nth - 1)/nth;
 4498
 4499    // row range for this thread
 4500    const int ir0 = dr*ith;
 4501    const int ir1 = MIN(ir0 + dr, nr);
 4502
 4503    for (int ir = ir0; ir < ir1; ++ir) {
 4504        // src0 and dst are viewed with shape of src1 and offset
 4505        // => same indices
 4506        const int i3 = ir/(ne12*ne11);
 4507        const int i2 = (ir - i3*ne12*ne11)/ne11;
 4508        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
 4509
 4510        ggml_vec_cpy_f32(nc,
 4511                (float *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
 4512                (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
 4513    }
 4514}
 4515
 4516static void ggml_compute_forward_set_i32(
 4517        const ggml_compute_params * params,
 4518        ggml_tensor * dst) {
 4519
 4520    const ggml_tensor * src0 = dst->src[0];
 4521    const ggml_tensor * src1 = dst->src[1];
 4522
 4523    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 4524    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 4525
 4526    // view src0 and dst with these strides and data offset in bytes during set
 4527    // nb0 is implicitly element_size because src0 and dst are contiguous
 4528    size_t nb1     = ((int32_t *) dst->op_params)[0];
 4529    size_t nb2     = ((int32_t *) dst->op_params)[1];
 4530    size_t nb3     = ((int32_t *) dst->op_params)[2];
 4531    size_t offset  = ((int32_t *) dst->op_params)[3];
 4532    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
 4533
 4534    if (!inplace) {
 4535        if (params->ith == 0) {
 4536            // memcpy needs to be synchronized across threads to avoid race conditions.
 4537            // => do it in INIT phase
 4538            memcpy(
 4539                ((char *)  dst->data),
 4540                ((char *) src0->data),
 4541                ggml_nbytes(dst));
 4542        }
 4543        ggml_barrier(params->threadpool);
 4544    }
 4545
 4546    const int ith = params->ith;
 4547    const int nth = params->nth;
 4548
 4549    const int nr = ggml_nrows(src1);
 4550    const int nc = src1->ne[0];
 4551
 4552    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
 4553    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
 4554
 4555    // src0 and dst as viewed during set
 4556    const size_t nb0 = ggml_element_size(src0);
 4557
 4558    const int im0 = (ne10 == 0 ? 0 : ne10-1);
 4559    const int im1 = (ne11 == 0 ? 0 : ne11-1);
 4560    const int im2 = (ne12 == 0 ? 0 : ne12-1);
 4561    const int im3 = (ne13 == 0 ? 0 : ne13-1);
 4562
 4563    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
 4564
 4565    GGML_ASSERT(nb10 == sizeof(int32_t));
 4566
 4567    // rows per thread
 4568    const int dr = (nr + nth - 1)/nth;
 4569
 4570    // row range for this thread
 4571    const int ir0 = dr*ith;
 4572    const int ir1 = MIN(ir0 + dr, nr);
 4573
 4574    for (int ir = ir0; ir < ir1; ++ir) {
 4575        // src0 and dst are viewed with shape of src1 and offset
 4576        // => same indices
 4577        const int i3 = ir/(ne12*ne11);
 4578        const int i2 = (ir - i3*ne12*ne11)/ne11;
 4579        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
 4580
 4581        ggml_vec_cpy_i32(nc,
 4582                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
 4583                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
 4584    }
 4585}
 4586
 4587void ggml_compute_forward_set(
 4588        const ggml_compute_params * params,
 4589        ggml_tensor * dst) {
 4590
 4591    const ggml_tensor * src0 = dst->src[0];
 4592
 4593    switch (src0->type) {
 4594        case GGML_TYPE_F32:
 4595            {
 4596                ggml_compute_forward_set_f32(params, dst);
 4597            } break;
 4598        case GGML_TYPE_I32:
 4599            {
 4600                ggml_compute_forward_set_i32(params, dst);
 4601            } break;
 4602        case GGML_TYPE_F16:
 4603        case GGML_TYPE_BF16:
 4604        case GGML_TYPE_Q4_0:
 4605        case GGML_TYPE_Q4_1:
 4606        case GGML_TYPE_Q5_0:
 4607        case GGML_TYPE_Q5_1:
 4608        case GGML_TYPE_Q8_0:
 4609        case GGML_TYPE_Q8_1:
 4610        case GGML_TYPE_MXFP4:
 4611        case GGML_TYPE_Q2_K:
 4612        case GGML_TYPE_Q3_K:
 4613        case GGML_TYPE_Q4_K:
 4614        case GGML_TYPE_Q5_K:
 4615        case GGML_TYPE_Q6_K:
 4616        case GGML_TYPE_TQ1_0:
 4617        case GGML_TYPE_TQ2_0:
 4618        case GGML_TYPE_IQ2_XXS:
 4619        case GGML_TYPE_IQ2_XS:
 4620        case GGML_TYPE_IQ3_XXS:
 4621        case GGML_TYPE_IQ1_S:
 4622        case GGML_TYPE_IQ1_M:
 4623        case GGML_TYPE_IQ4_NL:
 4624        case GGML_TYPE_IQ4_XS:
 4625        case GGML_TYPE_IQ3_S:
 4626        case GGML_TYPE_IQ2_S:
 4627        default:
 4628            {
 4629                GGML_ABORT("fatal error");
 4630            }
 4631    }
 4632}
 4633
 4634// ggml_compute_forward_cpy
 4635
 4636void ggml_compute_forward_cpy(
 4637        const ggml_compute_params * params,
 4638        ggml_tensor * dst) {
 4639    ggml_compute_forward_dup(params, dst);
 4640}
 4641
 4642// ggml_compute_forward_cont
 4643
 4644void ggml_compute_forward_cont(
 4645        const ggml_compute_params * params,
 4646        ggml_tensor * dst) {
 4647    ggml_compute_forward_dup(params, dst);
 4648}
 4649
 4650// ggml_compute_forward_get_rows
 4651
 4652static void ggml_compute_forward_get_rows_q(
 4653        const ggml_compute_params * params,
 4654              ggml_tensor * dst) {
 4655
 4656    const ggml_tensor * src0 = dst->src[0];
 4657    const ggml_tensor * src1 = dst->src[1];
 4658
 4659    GGML_TENSOR_BINARY_OP_LOCALS
 4660
 4661    const int64_t nc = ne00;
 4662    const int64_t nr = ggml_nelements(src1);
 4663
 4664    const ggml_type type = src0->type;
 4665    ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;
 4666
 4667    assert(ne0  == nc);
 4668    assert(ne02 == ne11);
 4669    assert(nb00 == ggml_type_size(type));
 4670    assert(ggml_nrows(dst) == nr);
 4671
 4672    const int ith = params->ith;
 4673    const int nth = params->nth;
 4674
 4675    // rows per thread
 4676    const int dr = (nr + nth - 1)/nth;
 4677
 4678    // row range for this thread
 4679    const int ir0 = dr*ith;
 4680    const int ir1 = MIN(ir0 + dr, nr);
 4681
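    // for each index in src1, dequantize the selected src0 row into the corresponding dst row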
 4682    for (int64_t i = ir0; i < ir1; ++i) {
 4683        const int64_t i12 = i/(ne11*ne10);
 4684        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
 4685        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
 4686        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 4687
 4688        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 4689
 4690        dequantize_row_q(
 4691                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 4692                     (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
 4693    }
 4694}
 4695
 4696static void ggml_compute_forward_get_rows_f16(
 4697        const ggml_compute_params * params,
 4698              ggml_tensor * dst) {
 4699
 4700    const ggml_tensor * src0 = dst->src[0];
 4701    const ggml_tensor * src1 = dst->src[1];
 4702
 4703    GGML_TENSOR_BINARY_OP_LOCALS
 4704
 4705    const int64_t nc = ne00;
 4706    const int64_t nr = ggml_nelements(src1);
 4707
 4708    assert(ne0  == nc);
 4709    assert(ne02 == ne11);
 4710    assert(nb00 == sizeof(ggml_fp16_t));
 4711    assert(ggml_nrows(dst) == nr);
 4712
 4713    const int ith = params->ith;
 4714    const int nth = params->nth;
 4715
 4716    // rows per thread
 4717    const int dr = (nr + nth - 1)/nth;
 4718
 4719    // row range for this thread
 4720    const int ir0 = dr*ith;
 4721    const int ir1 = MIN(ir0 + dr, nr);
 4722
 4723    for (int64_t i = ir0; i < ir1; ++i) {
 4724        const int64_t i12 = i/(ne11*ne10);
 4725        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
 4726        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
 4727        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 4728
 4729        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 4730
 4731        ggml_cpu_fp16_to_fp32(
 4732            (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 4733                       (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
 4734    }
 4735}
 4736
 4737static void ggml_compute_forward_get_rows_bf16(
 4738        const ggml_compute_params * params,
 4739              ggml_tensor * dst) {
 4740
 4741    const ggml_tensor * src0 = dst->src[0];
 4742    const ggml_tensor * src1 = dst->src[1];
 4743
 4744    GGML_TENSOR_BINARY_OP_LOCALS
 4745
 4746    const int64_t nc = ne00;
 4747    const int64_t nr = ggml_nelements(src1);
 4748
 4749    assert(ne0  == nc);
 4750    assert(ne02 == ne11);
 4751    assert(nb00 == sizeof(ggml_bf16_t));
 4752    assert(ggml_nrows(dst) == nr);
 4753
 4754    const int ith = params->ith;
 4755    const int nth = params->nth;
 4756
 4757    // rows per thread
 4758    const int dr = (nr + nth - 1)/nth;
 4759
 4760    // row range for this thread
 4761    const int ir0 = dr*ith;
 4762    const int ir1 = MIN(ir0 + dr, nr);
 4763
 4764    for (int64_t i = ir0; i < ir1; ++i) {
 4765        const int64_t i12 = i/(ne11*ne10);
 4766        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
 4767        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
 4768        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 4769
 4770        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 4771
 4772        ggml_cpu_bf16_to_fp32(
 4773            (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
 4774                        (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3), nc);
 4775    }
 4776}
 4777
 4778static void ggml_compute_forward_get_rows_f32(
 4779        const ggml_compute_params * params,
 4780              ggml_tensor * dst) {
 4781
 4782    const ggml_tensor * src0 = dst->src[0];
 4783    const ggml_tensor * src1 = dst->src[1];
 4784
 4785    GGML_TENSOR_BINARY_OP_LOCALS
 4786
 4787    const int64_t nc = ne00;
 4788    const int64_t nr = ggml_nelements(src1);
 4789
 4790    assert(ne0  == nc);
 4791    assert(ne02 == ne11);
 4792    assert(nb00 == sizeof(float));
 4793    assert(ggml_nrows(dst) == nr);
 4794
 4795    const int ith = params->ith;
 4796    const int nth = params->nth;
 4797
 4798    // rows per thread
 4799    const int dr = (nr + nth - 1)/nth;
 4800
 4801    // row range for this thread
 4802    const int ir0 = dr*ith;
 4803    const int ir1 = MIN(ir0 + dr, nr);
 4804
 4805    for (int64_t i = ir0; i < ir1; ++i) {
 4806        const int64_t i12 = i/(ne11*ne10);
 4807        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
 4808        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
 4809        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 4810
 4811        GGML_ASSERT(i01 >= 0 && i01 < ne01);
 4812
 4813        ggml_vec_cpy_f32(nc,
 4814                (float *) ((char *)  dst->data + i10*nb1  + i11*nb2  + i12*nb3),
 4815                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
 4816    }
 4817}
 4818
 4819void ggml_compute_forward_get_rows(
 4820        const ggml_compute_params * params,
 4821        ggml_tensor * dst) {
 4822
 4823    const ggml_tensor * src0 = dst->src[0];
 4824
 4825    switch (src0->type) {
 4826        case GGML_TYPE_Q4_0:
 4827        case GGML_TYPE_Q4_1:
 4828        case GGML_TYPE_Q5_0:
 4829        case GGML_TYPE_Q5_1:
 4830        case GGML_TYPE_Q8_0:
 4831        case GGML_TYPE_Q8_1:
 4832        case GGML_TYPE_MXFP4:
 4833        case GGML_TYPE_Q2_K:
 4834        case GGML_TYPE_Q3_K:
 4835        case GGML_TYPE_Q4_K:
 4836        case GGML_TYPE_Q5_K:
 4837        case GGML_TYPE_Q6_K:
 4838        case GGML_TYPE_TQ1_0:
 4839        case GGML_TYPE_TQ2_0:
 4840        case GGML_TYPE_IQ2_XXS:
 4841        case GGML_TYPE_IQ2_XS:
 4842        case GGML_TYPE_IQ3_XXS:
 4843        case GGML_TYPE_IQ1_S:
 4844        case GGML_TYPE_IQ1_M:
 4845        case GGML_TYPE_IQ4_NL:
 4846        case GGML_TYPE_IQ4_XS:
 4847        case GGML_TYPE_IQ3_S:
 4848        case GGML_TYPE_IQ2_S:
 4849            {
 4850                ggml_compute_forward_get_rows_q(params, dst);
 4851            } break;
 4852        case GGML_TYPE_F16:
 4853            {
 4854                ggml_compute_forward_get_rows_f16(params, dst);
 4855            } break;
 4856        case GGML_TYPE_BF16:
 4857            {
 4858                ggml_compute_forward_get_rows_bf16(params, dst);
 4859            } break;
 4860        case GGML_TYPE_F32:
 4861        case GGML_TYPE_I32:
 4862            {
 4863                ggml_compute_forward_get_rows_f32(params, dst);
 4864            } break;
 4865        default:
 4866            {
 4867                GGML_ABORT("fatal error");
 4868            }
 4869    }
 4888}
 4889
 4890template<typename idx_t>
 4891static void ggml_compute_forward_set_rows_f32(
 4892        const ggml_compute_params * params,
 4893              ggml_tensor * dst) {
 4894
 4895    const ggml_tensor * src0 = dst->src[0];
 4896    const ggml_tensor * src1 = dst->src[1];
 4897
 4898    GGML_TENSOR_BINARY_OP_LOCALS
 4899
 4900    const int64_t nc = ne00;
 4901    const int64_t nr = ne01;
 4902
 4903    assert(ne0  == nc);
 4904    assert(ne2  == ne02);
 4905    assert(ne3  == ne03);
 4906    assert(src0->type == GGML_TYPE_F32);
 4907    assert(ne02 % ne11 == 0);
 4908    assert(ne03 % ne12 == 0);
 4909
 4910    const int ith = params->ith;
 4911    const int nth = params->nth;
 4912
 4913    // rows per thread
 4914    const int64_t dr = (nr + nth - 1)/nth;
 4915
 4916    // row range for this thread
 4917    const int64_t ir0 = dr*ith;
 4918    const int64_t ir1 = std::min(ir0 + dr, nr);
 4919
 4920    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 4921
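    // for every source row, look up the destination row index in src1 (broadcast over dims 2/3)
    // and convert the f32 row into dst's storage type via from_float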
 4922    for (int64_t i03 = 0; i03 < ne03; ++i03) {
 4923        for (int64_t i02 = 0; i02 < ne02; ++i02) {
 4924            for (int64_t i = ir0; i < ir1; ++i) {
 4925                const int64_t i12 = i03%ne12;
 4926                const int64_t i11 = i02%ne11;
 4927                const int64_t i10 = i;
 4928
 4929                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 4930
 4931                GGML_ASSERT(i1 >= 0 && i1 < ne1);
 4932
 4933                from_float(
 4934                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
 4935                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
 4936            }
 4937        }
 4938    }
 4939}
 4940
 4941void ggml_compute_forward_set_rows(
 4942        const ggml_compute_params * params,
 4943        ggml_tensor * dst) {
 4944
 4945    const ggml_tensor * src0 = dst->src[0];
 4946    const ggml_tensor * src1 = dst->src[1];
 4947
 4948    switch (src0->type) {
 4949        case GGML_TYPE_F32:
 4950            {
 4951                if (src1->type == GGML_TYPE_I64) {
 4952                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
 4953                } else if (src1->type == GGML_TYPE_I32) {
 4954                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
 4955                } else {
 4956                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
 4957                }
 4958            } break;
 4959        default:
 4960            {
 4961                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
 4962            }
 4963    }
 4964}
 4965
 4966// ggml_compute_forward_get_rows_back
 4967
 4968static void ggml_compute_forward_get_rows_back_f32_f16(
 4969        const ggml_compute_params * params,
 4970              ggml_tensor * dst) {
 4971
 4972    const ggml_tensor * src0 = dst->src[0];
 4973    const ggml_tensor * src1 = dst->src[1];
 4974
 4975    if (params->ith != 0) {
 4976        return;
 4977    }
 4978
 4979    GGML_ASSERT(ggml_is_contiguous(dst));
 4980
 4981    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 4982
 4983    memset(dst->data, 0, ggml_nbytes(dst));
 4984
 4985    const int nc = src0->ne[0];
 4986    const int nr = ggml_nelements(src1);
 4987
 4988    GGML_ASSERT( dst->ne[0] == nc);
 4989    GGML_ASSERT(src0->nb[0] == sizeof(ggml_fp16_t));
 4990
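    // scatter-add: accumulate each incoming gradient row into the dst row selected by src1
    // (rows referenced multiple times accumulate), converting f16 -> f32 on the fly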
 4991    for (int i = 0; i < nr; ++i) {
 4992        const int r = ((int32_t *) src1->data)[i];
 4993
 4994        for (int j = 0; j < nc; ++j) {
 4995            ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + i*src0->nb[1]))[j];
 4996            ((float *) ((char *) dst->data + r*dst->nb[1]))[j] += GGML_CPU_FP16_TO_FP32(v);
 4997        }
 4998    }
 4999}
 5000
 5001static void ggml_compute_forward_get_rows_back_f32(
 5002        const ggml_compute_params * params,
 5003              ggml_tensor * dst) {
 5004
 5005    const ggml_tensor * src0 = dst->src[0];
 5006    const ggml_tensor * src1 = dst->src[1];
 5007
 5008    if (params->ith != 0) {
 5009        return;
 5010    }
 5011
 5012    GGML_ASSERT(ggml_is_contiguous(dst));
 5013
 5014    // ggml_compute_forward_dup_same_cont(params, opt0, dst);
 5015
 5016    memset(dst->data, 0, ggml_nbytes(dst));
 5017
 5018    const int nc = src0->ne[0];
 5019    const int nr = ggml_nelements(src1);
 5020
 5021    GGML_ASSERT( dst->ne[0] == nc);
 5022    GGML_ASSERT(src0->nb[0] == sizeof(float));
 5023
 5024    for (int i = 0; i < nr; ++i) {
 5025        const int r = ((int32_t *) src1->data)[i];
 5026
 5027        ggml_vec_add_f32(nc,
 5028                (float *) ((char *)  dst->data + r*dst->nb[1]),
 5029                (float *) ((char *)  dst->data + r*dst->nb[1]),
 5030                (float *) ((char *) src0->data + i*src0->nb[1]));
 5031    }
 5032}
 5033
 5034void ggml_compute_forward_get_rows_back(
 5035        const ggml_compute_params * params,
 5036        ggml_tensor * dst) {
 5037
 5038    const ggml_tensor * src0 = dst->src[0];
 5039
 5040    switch (src0->type) {
 5041        case GGML_TYPE_F16:
 5042            {
 5043                ggml_compute_forward_get_rows_back_f32_f16(params, dst);
 5044            } break;
 5045        case GGML_TYPE_F32:
 5046            {
 5047                ggml_compute_forward_get_rows_back_f32(params, dst);
 5048            } break;
 5049        default:
 5050            {
 5051                GGML_ABORT("fatal error");
 5052            }
 5053    }
 5072}
 5073
 5074// ggml_compute_forward_diag
 5075
 5076static void ggml_compute_forward_diag_f32(
 5077        const ggml_compute_params * params,
 5078        ggml_tensor * dst) {
 5079
 5080    const ggml_tensor * src0 = dst->src[0];
 5081
 5082    if (params->ith != 0) {
 5083        return;
 5084    }
 5085
 5086    // TODO: handle transposed/permuted matrices
 5087
 5088    GGML_TENSOR_UNARY_OP_LOCALS
 5089
 5090    GGML_ASSERT(ne00 == ne0);
 5091    GGML_ASSERT(ne00 == ne1);
 5092    GGML_ASSERT(ne01 == 1);
 5093    GGML_ASSERT(ne02 == ne2);
 5094    GGML_ASSERT(ne03 == ne3);
 5095
 5096    GGML_ASSERT(nb00 == sizeof(float));
 5097    GGML_ASSERT(nb0  == sizeof(float));
 5098
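    // expand the single src row into a diagonal matrix: dst(i1, i0) = src(i1) if i0 == i1, else 0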
 5099    for (int i3 = 0; i3 < ne3; i3++) {
 5100        for (int i2 = 0; i2 < ne2; i2++) {
 5101            for (int i1 = 0; i1 < ne1; i1++) {
 5102                float * d = (float *)((char *)  dst->data + i3*nb3  + i2*nb2 + i1*nb1);
 5103                float * s = (float *)((char *) src0->data + i3*nb03 + i2*nb02);
 5104                for (int i0 = 0; i0 < i1; i0++) {
 5105                    d[i0] = 0;
 5106                }
 5107                d[i1] = s[i1];
 5108                for (int i0 = i1+1; i0 < ne0; i0++) {
 5109                    d[i0] = 0;
 5110                }
 5111            }
 5112        }
 5113    }
 5114}
 5115
 5116void ggml_compute_forward_diag(
 5117        const ggml_compute_params * params,
 5118        ggml_tensor * dst) {
 5119
 5120    const ggml_tensor * src0 = dst->src[0];
 5121
 5122    switch (src0->type) {
 5123        case GGML_TYPE_F32:
 5124            {
 5125                ggml_compute_forward_diag_f32(params, dst);
 5126            } break;
 5127        default:
 5128            {
 5129                GGML_ABORT("fatal error");
 5130            }
 5131    }
 5132}
 5133
 5134// ggml_compute_forward_diag_mask_inf
 5135
 5136static void ggml_compute_forward_diag_mask_f32(
 5137        const ggml_compute_params * params,
 5138        ggml_tensor * dst,
 5139        const float value) {
 5140
 5141    const ggml_tensor * src0 = dst->src[0];
 5142
 5143    const int ith = params->ith;
 5144    const int nth = params->nth;
 5145
 5146    const int  n_past  = ((int32_t *) dst->op_params)[0];
 5147    const bool inplace = src0->data == dst->data;
 5148
 5149    GGML_ASSERT(n_past >= 0);
 5150
 5151    if (!inplace) {
 5152        if (ith == 0) {
5153            // the copy must be done by a single thread and must complete before any
5154            // thread writes mask values => synchronize via the barrier below
 5155            GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
 5156            GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
 5157            memcpy(
 5158                ((char *)  dst->data),
 5159                ((char *) src0->data),
 5160                ggml_nbytes(dst));
 5161        }
 5162        ggml_barrier(params->threadpool);
 5163    }
 5164
 5165    // TODO: handle transposed/permuted matrices
 5166
 5167    const int n  = ggml_nrows(src0);
 5168    const int nc = src0->ne[0];
 5169    const int nr = src0->ne[1];
 5170    const int nz = n/nr;
 5171
 5172    GGML_ASSERT( dst->nb[0] == sizeof(float));
 5173    GGML_ASSERT(src0->nb[0] == sizeof(float));
 5174
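    // set every element above the shifted diagonal (i > n_past + j) to `value`;
    // with value = -INFINITY this implements the causal attention mask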
 5175    for (int k = 0; k < nz; k++) {
 5176        for (int j = ith; j < nr; j += nth) {
 5177            for (int i = n_past; i < nc; i++) {
 5178                if (i > n_past + j) {
 5179                    *(float *)((char *) dst->data + k*dst->nb[2] + j*dst->nb[1] + i*dst->nb[0]) = value;
 5180                }
 5181            }
 5182        }
 5183    }
 5184}
 5185
 5186void ggml_compute_forward_diag_mask_inf(
 5187        const ggml_compute_params * params,
 5188        ggml_tensor * dst) {
 5189
 5190    const ggml_tensor * src0 = dst->src[0];
 5191
 5192    switch (src0->type) {
 5193        case GGML_TYPE_F32:
 5194            {
 5195                ggml_compute_forward_diag_mask_f32(params, dst, -INFINITY);
 5196            } break;
 5197        default:
 5198            {
 5199                GGML_ABORT("fatal error");
 5200            }
 5201    }
 5202}
 5203
 5204void ggml_compute_forward_diag_mask_zero(
 5205        const ggml_compute_params * params,
 5206        ggml_tensor * dst) {
 5207
 5208    const ggml_tensor * src0 = dst->src[0];
 5209
 5210    switch (src0->type) {
 5211        case GGML_TYPE_F32:
 5212            {
 5213                ggml_compute_forward_diag_mask_f32(params, dst, 0);
 5214            } break;
 5215        default:
 5216            {
 5217                GGML_ABORT("fatal error");
 5218            }
 5219    }
 5220}
 5221
 5222// ggml_compute_forward_soft_max
 5223
 5224static void ggml_compute_forward_soft_max_f32(
 5225        const ggml_compute_params * params,
 5226              ggml_tensor * dst) {
 5227
 5228    const ggml_tensor * src0 = dst->src[0];
 5229    const ggml_tensor * src1 = dst->src[1];
 5230    const ggml_tensor * src2 = dst->src[2];
 5231
 5232    assert(ggml_is_contiguous(dst));
 5233    assert(ggml_are_same_shape(src0, dst));
 5234
 5235    float scale    = 1.0f;
 5236    float max_bias = 0.0f;
 5237
 5238    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
 5239    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 5240
 5241    const int ith = params->ith;
 5242    const int nth = params->nth;
 5243
 5244    GGML_TENSOR_UNARY_OP_LOCALS
 5245
 5246    const int64_t nb11 = src1 ? src1->nb[1] : 1;
 5247    const int64_t nb12 = src1 ? src1->nb[2] : 1;
 5248    const int64_t nb13 = src1 ? src1->nb[3] : 1;
 5249
 5250    const int64_t ne12 = src1 ? src1->ne[2] : 1;
 5251    const int64_t ne13 = src1 ? src1->ne[3] : 1;
 5252
 5253    // TODO: is this supposed to be ceil instead of floor?
 5254    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
 5255    const uint32_t n_head      = ne02;
 5256    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 5257
 5258    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 5259    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 5260
 5261    float * wp = (float *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
 5262
 5263    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 5264
 5265    // sinks
 5266    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
 5267
 5268    for (int64_t i03 = 0; i03 < ne03; i03++) {
 5269        for (int64_t i02 = 0; i02 < ne02; i02++) {
 5270            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
 5271                const int64_t i11 = i01;
 5272                const int64_t i12 = i02%ne12;
 5273                const int64_t i13 = i03%ne13;
 5274
 5275                // ALiBi
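                // heads below n_head_log2 use slope m0^(h+1); the remaining heads use
                // m1^(2*(h - n_head_log2) + 1), interpolating for non-power-of-two head counts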
 5276                const uint32_t h = i02; // head
5277                const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
 5278
 5279                float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 5280                float * dp = (float *)((char *)  dst->data + i01*nb1  + i02*nb2  + i03*nb3);
 5281
 5282                // broadcast the mask across rows
 5283                ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
 5284                float       * mp_f32 = src1 ? (float       *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
 5285
 5286                ggml_vec_cpy_f32  (ne00, wp, sp);
 5287                ggml_vec_scale_f32(ne00, wp, scale);
 5288                if (mp_f32) {
 5289                    if (use_f16) {
 5290                        for (int i = 0; i < ne00; ++i) {
 5291                            wp[i] += slope*GGML_CPU_FP16_TO_FP32(mp_f16[i]);
 5292                        }
 5293                    } else {
 5294                        for (int i = 0; i < ne00; ++i) {
 5295                            wp[i] += slope*mp_f32[i];
 5296                        }
 5297                    }
 5298                }
 5299
 5300#ifndef NDEBUG
 5301                for (int i = 0; i < ne00; ++i) {
 5302                    //printf("p[%d] = %f\n", i, p[i]);
 5303                    assert(!isnan(wp[i]));
 5304                }
 5305#endif
 5306
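                // numerically stable softmax: subtract the row max before exponentiating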
 5307                float max = -INFINITY;
 5308                ggml_vec_max_f32(ne00, &max, wp);
 5309
 5310                // if we have sinks, make a correction as if they were included in the softmax
 5311                if (sk) {
 5312                    max = MAX(max, sk[i02]);
 5313                }
 5314
 5315                ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
 5316                assert(sum > 0.0);
 5317
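                // a sink only contributes exp(sk - max) to the normalizer; it has no output element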
 5318                if (sk) {
 5319                    sum += (ggml_float) expf(sk[i02] - max);
 5320                }
 5321
 5322                sum = 1.0/sum;
 5323                ggml_vec_scale_f32(ne00, dp, sum);
 5324
 5325#ifndef NDEBUG
 5326                for (int i = 0; i < ne00; ++i) {
 5327                    assert(!isnan(dp[i]));
 5328                    assert(!isinf(dp[i]));
 5329                }
 5330#endif
 5331            }
 5332        }
 5333    }
 5334}
 5335
 5336void ggml_compute_forward_soft_max(
 5337        const ggml_compute_params * params,
 5338              ggml_tensor * dst) {
 5339
 5340    const ggml_tensor * src0 = dst->src[0];
 5341
 5342    switch (src0->type) {
 5343        case GGML_TYPE_F32:
 5344            {
 5345                ggml_compute_forward_soft_max_f32(params, dst);
 5346            } break;
 5347        default:
 5348            {
 5349                GGML_ABORT("fatal error");
 5350            }
 5351    }
 5352}
 5353
 5354
 5355// ggml_compute_forward_soft_max_ext_back
 5356
 5357static void ggml_compute_forward_soft_max_ext_back_f32(
 5358        const ggml_compute_params * params,
 5359        ggml_tensor * dst) {
 5360
 5361    const ggml_tensor * src0 = dst->src[0];
 5362    const ggml_tensor * src1 = dst->src[1];
 5363
 5364    GGML_ASSERT(ggml_is_contiguous(src0));
 5365    GGML_ASSERT(ggml_is_contiguous(src1));
 5366    GGML_ASSERT(ggml_is_contiguous(dst));
 5367    GGML_ASSERT(ggml_are_same_shape(src0, dst));
 5368    GGML_ASSERT(ggml_are_same_shape(src1, dst));
 5369
 5370    float scale    = 1.0f;
 5371    float max_bias = 0.0f;
 5372
 5373    memcpy(&scale,    (const float *) dst->op_params + 0, sizeof(float));
 5374    memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
 5375
 5376    GGML_ASSERT(max_bias == 0.0f);
 5377
 5378    // TODO: handle transposed/permuted matrices
 5379
 5380    const int ith = params->ith;
 5381    const int nth = params->nth;
 5382
 5383    const int nc = src0->ne[0];
 5384    const int nr = ggml_nrows(src0);
 5385
 5386    // rows per thread
 5387    const int dr = (nr + nth - 1)/nth;
 5388
 5389    // row range for this thread
 5390    const int ir0 = dr*ith;
 5391    const int ir1 = MIN(ir0 + dr, nr);
 5392
 5393    for (int i1 = ir0; i1 < ir1; i1++) {
5394        float * dy = (float *)((char *) src0->data + i1*src0->nb[1]);
5395        float * y  = (float *)((char *) src1->data + i1*src1->nb[1]);
5396        float * dx = (float *)((char *)  dst->data + i1*dst->nb[1]);
 5397
 5398#ifndef NDEBUG
 5399        for (int i = 0; i < nc; ++i) {
 5400            //printf("p[%d] = %f\n", i, p[i]);
 5401            assert(!isnan(dy[i]));
 5402            assert(!isnan(y[i]));
 5403        }
 5404#endif
 5405        // Jii = yi - yi*yi
 5406        // Jij = -yi*yj
 5407        // J = diag(y)-y.T*y
 5408        // dx = J * dy
 5409        // dxk = sum_i(Jki * dyi)
 5410        // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
 5411        // dxk = sum_i(-yk*yi * dyi) + yk*yk*dyk + yk*dyk - yk*yk*dyk
 5412        // dxk = sum_i(-yk*yi * dyi) + yk*dyk
 5413        // dxk = -yk * sum_i(yi * dyi) + yk*dyk
 5414        // dxk = -yk * dot(y, dy) + yk*dyk
 5415        // dxk = yk * (- dot(y, dy) + dyk)
 5416        // dxk = yk * (dyk - dot(y, dy))
 5417        //
 5418        // post-order:
 5419        // dot_y_dy := dot(y, dy)
 5420        // dx := dy
 5421        // dx := dx - dot_y_dy
 5422        // dx := dx * y
 5423
 5424        // linear runtime, no additional memory
 5425        float dot_y_dy = 0;
 5426        ggml_vec_dot_f32  (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
 5427        ggml_vec_cpy_f32  (nc, dx, dy);
 5428        ggml_vec_acc1_f32 (nc, dx, -dot_y_dy);
 5429        ggml_vec_mul_f32  (nc, dx, dx, y);
 5430        ggml_vec_scale_f32(nc, dx, scale);
 5431
 5432#ifndef NDEBUG
 5433        for (int i = 0; i < nc; ++i) {
 5434            assert(!isnan(dx[i]));
 5435            assert(!isinf(dx[i]));
 5436        }
 5437#endif
 5438    }
 5439}
 5440
 5441void ggml_compute_forward_soft_max_ext_back(
 5442        const ggml_compute_params * params,
 5443        ggml_tensor * dst) {
 5444
 5445    const ggml_tensor * src0 = dst->src[0];
 5446
 5447    switch (src0->type) {
 5448        case GGML_TYPE_F32:
 5449            {
 5450                ggml_compute_forward_soft_max_ext_back_f32(params, dst);
 5451            } break;
 5452        default:
 5453            {
 5454                GGML_ABORT("fatal error");
 5455            }
 5456    }
 5457}
 5458
 5459// ggml_compute_forward_clamp
 5460
 5461static void ggml_compute_forward_clamp_f32(
 5462        const ggml_compute_params * params,
 5463        ggml_tensor * dst) {
 5464
 5465    const ggml_tensor * src0 = dst->src[0];
 5466
 5467    float min;
 5468    float max;
 5469    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
 5470    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 5471
 5472    const int ith = params->ith;
 5473    const int nth = params->nth;
 5474
 5475    const int n  = ggml_nrows(src0);
 5476    const int nc = src0->ne[0];
 5477
 5478    const size_t nb00 = src0->nb[0];
 5479    const size_t nb01 = src0->nb[1];
 5480
 5481    const size_t nb0 = dst->nb[0];
 5482    const size_t nb1 = dst->nb[1];
 5483
 5484    GGML_ASSERT( nb0 == sizeof(float));
 5485    GGML_ASSERT(nb00 == sizeof(float));
 5486
 5487    for (int j = ith; j < n; j += nth) {
 5488        float * dst_ptr  = (float *) ((char *)  dst->data + j*nb1);
 5489        float * src0_ptr = (float *) ((char *) src0->data + j*nb01);
 5490
 5491        for (int i = 0; i < nc; i++) {
 5492            dst_ptr[i] = MAX(MIN(src0_ptr[i], max), min);
 5493        }
 5494    }
 5495}
 5496
 5497static void ggml_compute_forward_clamp_f16(
 5498    const ggml_compute_params * params,
 5499    ggml_tensor * dst) {
 5500
 5501    const ggml_tensor * src0 = dst->src[0];
 5502
 5503    float min;
 5504    float max;
 5505    memcpy(&min, (float *) dst->op_params + 0, sizeof(float));
 5506    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
 5507
 5508    const int ith = params->ith;
 5509    const int nth = params->nth;
 5510
 5511    const int n  = ggml_nrows(src0);
 5512    const int nc = src0->ne[0];
 5513
 5514    const size_t nb00 = src0->nb[0];
 5515    const size_t nb01 = src0->nb[1];
 5516
 5517    const size_t nb0 = dst->nb[0];
 5518    const size_t nb1 = dst->nb[1];
 5519
 5520    GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
 5521    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 5522
 5523    for (int j = ith; j < n; j += nth) {
 5524        ggml_fp16_t * dst_ptr  = (ggml_fp16_t *) ((char *)  dst->data + j*nb1);
 5525        ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
 5526
 5527        for (int i = 0; i < nc; i++) {
 5528            float v = GGML_CPU_FP16_TO_FP32(src0_ptr[i]);
 5529            dst_ptr[i] = GGML_CPU_FP32_TO_FP16(MAX(MIN(v, max), min));
 5530        }
 5531    }
 5532}
 5533
 5534void ggml_compute_forward_clamp(
 5535        const ggml_compute_params * params,
 5536        ggml_tensor * dst) {
 5537
 5538    const ggml_tensor * src0 = dst->src[0];
 5539
 5540    switch (src0->type) {
 5541        case GGML_TYPE_F32:
 5542            {
 5543                ggml_compute_forward_clamp_f32(params, dst);
 5544            } break;
 5545        case GGML_TYPE_F16:
 5546            {
 5547                ggml_compute_forward_clamp_f16(params, dst);
 5548            } break;
 5549        case GGML_TYPE_BF16:
 5550        case GGML_TYPE_Q4_0:
 5551        case GGML_TYPE_Q4_1:
 5552        case GGML_TYPE_Q5_0:
 5553        case GGML_TYPE_Q5_1:
 5554        case GGML_TYPE_Q8_0:
 5555        case GGML_TYPE_Q8_1:
 5556        case GGML_TYPE_MXFP4:
 5557        case GGML_TYPE_Q2_K:
 5558        case GGML_TYPE_Q3_K:
 5559        case GGML_TYPE_Q4_K:
 5560        case GGML_TYPE_Q5_K:
 5561        case GGML_TYPE_Q6_K:
 5562        case GGML_TYPE_TQ1_0:
 5563        case GGML_TYPE_TQ2_0:
 5564        case GGML_TYPE_IQ2_XXS:
 5565        case GGML_TYPE_IQ2_XS:
 5566        case GGML_TYPE_IQ3_XXS:
 5567        case GGML_TYPE_IQ1_S:
 5568        case GGML_TYPE_IQ1_M:
 5569        case GGML_TYPE_IQ4_NL:
 5570        case GGML_TYPE_IQ4_XS:
 5571        case GGML_TYPE_IQ3_S:
 5572        case GGML_TYPE_IQ2_S:
 5573        case GGML_TYPE_Q8_K:
 5574        case GGML_TYPE_I8:
 5575        case GGML_TYPE_I16:
 5576        case GGML_TYPE_I32:
 5577        case GGML_TYPE_I64:
 5578        case GGML_TYPE_F64:
 5579        case GGML_TYPE_COUNT:
 5580            {
 5581                GGML_ABORT("fatal error");
 5582            }
 5583    }
 5584}
 5585
 5586// ggml_compute_forward_rope
 5587
 5588static float rope_yarn_ramp(const float low, const float high, const int i0) {
 5589    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
 5590    return 1 - MIN(1, MAX(0, y));
 5591}
 5592
 5593// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 5594// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 5595static void rope_yarn(
 5596    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
 5597    float * cos_theta, float * sin_theta) {
 5598    // Get n-d rotational scaling corrected for extrapolation
 5599    float theta_interp = freq_scale * theta_extrap;
 5600    float theta = theta_interp;
 5601    if (ext_factor != 0.0f) {
 5602        float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
 5603        theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
 5604
 5605        // Get n-d magnitude scaling corrected for interpolation
 5606        mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
 5607    }
 5608    *cos_theta = cosf(theta) * mscale;
 5609    *sin_theta = sinf(theta) * mscale;
 5610}
 5611
 5612static void ggml_rope_cache_init(
 5613     float theta_base, float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
 5614     float * cache, float sin_sign, float theta_scale) {
 5615    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
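    // the cache holds interleaved (cos, sin) values, one pair per two channels;
    // sin_sign is -1 for the backward pass, which applies the inverse (transposed) rotation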
 5616    float theta = theta_base;
 5617    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 5618        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
 5619        rope_yarn(
 5620            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
 5621        );
 5622        cache[i0 + 1] *= sin_sign;
 5623
 5624        theta *= theta_scale;
 5625    }
 5626}
 5627
 5628static void ggml_mrope_cache_init(
 5629     float theta_base_t, float theta_base_h, float theta_base_w, float theta_base_e, int sections[4], bool is_imrope, bool indep_sects,
 5630     float freq_scale, const float * freq_factors, float corr_dims[2], int64_t ne0, float ext_factor, float mscale,
 5631     float * cache, float sin_sign, float theta_scale) {
 5632    // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
 5633    float theta_t = theta_base_t;
 5634    float theta_h = theta_base_h;
 5635    float theta_w = theta_base_w;
 5636    float theta_e = theta_base_e;  // extra position id for vision encoder
 5637    int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
 5638    int sec_w = sections[1] + sections[0];
 5639    int sec_e = sections[2] + sec_w;
 5640    GGML_ASSERT(sect_dims <= ne0);
 5641
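    // in the default (non-interleaved) layout, channel pairs map to position components by sector:
    // [0, sections[0]) -> temporal, [sections[0], sec_w) -> height, [sec_w, sec_e) -> width,
    // [sec_e, sect_dims) -> extra (vision) position; imrope interleaves the components instead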
 5642    for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
 5643        const float ff = freq_factors ? freq_factors[i0/2] : 1.0f;
 5644
 5645        int sector = (i0 / 2) % sect_dims;
 5646        if (indep_sects) {
5647            // compute theta independently for each dim section
5648            // (i.e. reset the corresponding theta when `i0` goes from one section to another)
 5649            if (sector == 0) {
 5650                theta_t = theta_base_t;
 5651            }
 5652            else if (sector == sections[0]) {
5653                theta_h = theta_base_h;
 5654            }
 5655            else if (sector == sec_w) {
 5656                theta_w = theta_base_w;
 5657            }
 5658            else if (sector == sec_e) {
 5659                theta_e = theta_base_e;
 5660            }
 5661        }
 5662
 5663        float theta = theta_t;
5664        if (is_imrope) { // qwen3vl applies interleaved mrope
 5665            if (sector % 3 == 1 && sector < 3 * sections[1]) {
 5666                theta = theta_h;
 5667            } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
 5668                theta = theta_w;
 5669            } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
 5670                theta = theta_t;
 5671            } else {
 5672                theta = theta_e;
 5673            }
 5674        } else {
 5675            if (sector >= sections[0] && sector < sec_w) {
 5676                theta = theta_h;
 5677            }
 5678            else if (sector >= sec_w && sector < sec_w + sections[2]) {
 5679                theta = theta_w;
 5680            }
 5681            else if (sector >= sec_w + sections[2]) {
 5682                theta = theta_e;
 5683            }
 5684        }
 5685
 5686        rope_yarn(
 5687            theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1]
 5688        );
 5689        cache[i0 + 1] *= sin_sign;
 5690
 5691        theta_t *= theta_scale;
 5692        theta_w *= theta_scale;
 5693        theta_h *= theta_scale;
 5694        theta_e *= theta_scale;
 5695    }
 5696}
 5697
 5698
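// rotate n/2 channel pairs: treat (x0, x1) = (src[ic], src[ic + n_offset]) as a 2-D vector and
// rotate it by theta, i.e. dst0 = x0*cos_theta - x1*sin_theta, dst1 = x0*sin_theta + x1*cos_theta;
// n_offset is 1 for interleaved (NORMAL) rope and n_dims/2 for the split-halves variants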
 5699template<typename T>
 5700static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
5701    for (int64_t i0 = 0; i0 < n; i0 += 2) {
5702        const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
5703
5704        const float cos_theta = cache[i0 + 0];
5705        const float sin_theta = cache[i0 + 1];
5706
5707        const T * const src = src_data + ic;
5708        T * dst             = dst_data + ic;
5709
5710        const float x0 = type_conversion_table<T>::to_f32(src[0]);
5711        const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
5712
5713        dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
5714        dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
5715    }
 5716}
 5717
 5718template<typename T> //float or ggml_fp16_t
 5719static void ggml_compute_forward_rope_flt(
 5720        const ggml_compute_params * params,
 5721        ggml_tensor * dst,
 5722        const bool forward) {
 5723
 5724    const ggml_tensor * src0 = dst->src[0];
 5725    const ggml_tensor * src1 = dst->src[1];
 5726    const ggml_tensor * src2 = dst->src[2];
 5727
 5728    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
 5729    GGML_ASSERT(src1->type == GGML_TYPE_I32);
 5730
 5731    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 5732    int sections[4];
 5733
 5734    //const int n_past     = ((int32_t *) dst->op_params)[0];
 5735    const int n_dims     = ((int32_t *) dst->op_params)[1];
 5736    const int mode       = ((int32_t *) dst->op_params)[2];
 5737    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
 5738    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 5739
 5740    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
 5741    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
 5742    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
 5743    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
 5744    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
 5745    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 5746    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
 5747
 5748    GGML_TENSOR_UNARY_OP_LOCALS
 5749
 5750    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
 5751    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 5752
 5753    GGML_ASSERT(nb0 == nb00);
 5754    GGML_ASSERT(nb0 == sizeof(T));
 5755
 5756    const int ith = params->ith;
 5757    const int nth = params->nth;
 5758
 5759    const int nr = ggml_nrows(dst);
 5760
 5761    GGML_ASSERT(n_dims <= ne0);
 5762    GGML_ASSERT(n_dims % 2 == 0);
 5763
 5764    // rows per thread
 5765    const int dr = (nr + nth - 1)/nth;
 5766
 5767    // row range for this thread
 5768    const int ir0 = dr*ith;
 5769    const int ir1 = MIN(ir0 + dr, nr);
 5770
5771    // running row index, used to skip rows outside this thread's range
 5772    int ir = 0;
 5773
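    // each channel pair i0 is rotated by an angle of theta_base * freq_base^(-i0/n_dims) (before
    // YaRN corrections); the cache helpers apply the per-pair factor incrementally via theta_scale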
 5774    const float theta_scale = powf(freq_base, -2.0f/n_dims);
 5775
 5776    float corr_dims[2];
 5777    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 5778
5779    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl applies interleaved mrope
 5780    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
 5781    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 5782
 5783    if (mrope_used) {
 5784        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
 5785    }
 5786
 5787    if (is_vision) {
 5788        GGML_ASSERT(n_dims == ne0/2);
 5789    }
 5790
 5791    const float * freq_factors = NULL;
 5792    if (src2 != NULL) {
 5793        GGML_ASSERT(src2->type == GGML_TYPE_F32);
 5794        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
 5795        freq_factors = (const float *) src2->data;
 5796    }
 5797
 5798    // backward process uses inverse rotation by cos and sin.
 5799    // cos and sin build a rotation matrix, where the inverse is the transpose.
 5800    // this essentially just switches the sign of sin.
 5801    const float sin_sign = forward ? 1.0f : -1.0f;
 5802
 5803    const int32_t * pos = (const int32_t *) src1->data;
 5804
 5805    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
 5806        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
 5807
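            // the sin/cos cache is per-thread scratch (padded by CACHE_LINE_SIZE_F32 to avoid
            // false sharing); it is computed once per position and reused for all heads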
 5808            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
 5809            if (!mrope_used) {
 5810                const int64_t p = pos[i2];
 5811                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 5812            }
 5813            else {
 5814                const int64_t p_t = pos[i2];
 5815                const int64_t p_h = pos[i2 + ne2];
 5816                const int64_t p_w = pos[i2 + ne2 * 2];
 5817                const int64_t p_e = pos[i2 + ne2 * 3];
 5818                ggml_mrope_cache_init(
 5819                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
 5820                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
 5821            }
 5822
 5823            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
 5824                if (ir++ < ir0) continue;
 5825                if (ir   > ir1) break;
 5826
 5827                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
 5828                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
 5829
 5830                switch (mode) {
 5831                    case GGML_ROPE_TYPE_NORMAL:
 5832                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
 5833                        break;
 5834                    case GGML_ROPE_TYPE_NEOX:
 5835                    case GGML_ROPE_TYPE_MROPE:
 5836                    case GGML_ROPE_TYPE_IMROPE:
 5837                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
 5838                        break;
 5839                    case GGML_ROPE_TYPE_VISION:
 5840                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
 5841                        break;
 5842                    default:
 5843                        GGML_ABORT("rope type not supported");
 5844                }
 5845
 5846                if (!is_vision) {
5847                    // fill the remaining channels with data from the src tensor
 5848                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
 5849                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
 5850                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 5851
 5852                        dst_data[0] = src[0];
 5853                        dst_data[1] = src[1];
 5854                    }
 5855                }
 5856            } //attn-heads
 5857        }
 5858    }
 5859}
 5860
 5861void ggml_compute_forward_rope(
 5862        const ggml_compute_params * params,
 5863        ggml_tensor * dst) {
 5864
 5865    const ggml_tensor * src0 = dst->src[0];
 5866
 5867    switch (src0->type) {
 5868        case GGML_TYPE_F16:
 5869            {
 5870                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
 5871            } break;
 5872        case GGML_TYPE_F32:
 5873            {
 5874                ggml_compute_forward_rope_flt<float>(params, dst, true);
 5875            } break;
 5876        default:
 5877            {
 5878                GGML_ABORT("fatal error");
 5879            }
 5880    }
 5881}
 5882
 5883// ggml_compute_forward_rope_back
 5884
 5885void ggml_compute_forward_rope_back(
 5886        const ggml_compute_params * params,
 5887        ggml_tensor * dst) {
 5888
 5889    const ggml_tensor * src0 = dst->src[0];
 5890
 5891    switch (src0->type) {
 5892        case GGML_TYPE_F16:
 5893            {
 5894                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
 5895            } break;
 5896        case GGML_TYPE_F32:
 5897            {
 5898                ggml_compute_forward_rope_flt<float>(params, dst, false);
 5899            } break;
 5900        default:
 5901            {
 5902                GGML_ABORT("fatal error");
 5903            }
 5904    }
 5905}
 5906
 5907// ggml_compute_forward_conv_transpose_1d
 5908
 5909static void ggml_compute_forward_conv_transpose_1d_f16_f32(
 5910        const ggml_compute_params * params,
 5911              ggml_tensor * dst) {
 5912
 5913    const ggml_tensor * src0 = dst->src[0];
 5914    const ggml_tensor * src1 = dst->src[1];
 5915
 5916    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 5917    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 5918    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 5919
 5920    GGML_TENSOR_BINARY_OP_LOCALS
 5921
 5922    const int ith = params->ith;
 5923    const int nth = params->nth;
 5924
 5925    const int nk = ne00*ne01*ne02;
 5926
 5927    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 5928    GGML_ASSERT(nb10 == sizeof(float));
 5929
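    // thread 0 packs permuted copies of the kernel and the input into wdata so the inner
    // loops can use contiguous dot products; all threads wait at the barrier below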
 5930    if (ith == 0) {
 5931        memset(params->wdata, 0, params->wsize);
 5932
 5933        // permute kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
 5934        {
 5935            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 5936
 5937            for (int64_t i02 = 0; i02 < ne02; i02++) {
 5938                for (int64_t i01 = 0; i01 < ne01; i01++) {
 5939                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i02*nb02 + i01*nb01);
 5940                    ggml_fp16_t * dst_data = wdata + i01*ne00*ne02;
 5941                    for (int64_t i00 = 0; i00 < ne00; i00++) {
 5942                        dst_data[i00*ne02 + i02] = src[i00];
 5943                    }
 5944                }
 5945            }
 5946        }
 5947
 5948        // permute source data (src1) from (L x Cin) to (Cin x L)
 5949        {
 5950            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
 5951            ggml_fp16_t * dst_data = wdata;
 5952
 5953            for (int64_t i11 = 0; i11 < ne11; i11++) {
 5954                const float * const src = (float *)((char *) src1->data + i11*nb11);
 5955                for (int64_t i10 = 0; i10 < ne10; i10++) {
 5956                    dst_data[i10*ne11 + i11] = GGML_CPU_FP32_TO_FP16(src[i10]);
 5957                }
 5958            }
 5959        }
 5960
 5961        // need to zero dst since we are accumulating into it
 5962        memset(dst->data, 0, ggml_nbytes(dst));
 5963    }
 5964    ggml_barrier(params->threadpool);
 5965
 5966    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 5967
 5968    // total rows in dst
 5969    const int nr = ne1;
 5970
 5971    // rows per thread
 5972    const int dr = (nr + nth - 1)/nth;
 5973
 5974    // row range for this thread
 5975    const int ir0 = dr*ith;
 5976    const int ir1 = MIN(ir0 + dr, nr);
 5977
 5978    ggml_fp16_t * const wdata     = (ggml_fp16_t *) params->wdata + 0;
 5979    ggml_fp16_t * const wdata_src = wdata + nk;
 5980
 5981    for (int i1 = ir0; i1 < ir1; i1++) {
 5982        float * dst_data = (float *)((char *) dst->data + i1*nb1);
 5983        ggml_fp16_t * wdata_kernel = wdata + i1*ne02*ne00;
 5984        for (int i10 = 0; i10 < ne10; i10++) {
 5985            const int i1n = i10*ne11;
 5986            for (int i00 = 0; i00 < ne00; i00++) {
 5987                float v = 0;
 5988                ggml_vec_dot_f16(ne02, &v, 0,
 5989                        (ggml_fp16_t *)    wdata_src + i1n, 0,
 5990                        (ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
 5991                dst_data[i10*s0 + i00] += v;
 5992            }
 5993        }
 5994    }
 5995}
 5996
 5997static void ggml_compute_forward_conv_transpose_1d_f32(
 5998        const ggml_compute_params * params,
 5999              ggml_tensor * dst) {
 6000
 6001    const ggml_tensor * src0 = dst->src[0];
 6002    const ggml_tensor * src1 = dst->src[1];
 6003
 6004    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 6005    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6006    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 6007
 6008    GGML_TENSOR_BINARY_OP_LOCALS
 6009
 6010    const int ith = params->ith;
 6011    const int nth = params->nth;
 6012
 6013    const int nk = ne00*ne01*ne02;
 6014
 6015    GGML_ASSERT(nb00 == sizeof(float));
 6016    GGML_ASSERT(nb10 == sizeof(float));
 6017
 6018    if (ith == 0) {
 6019        memset(params->wdata, 0, params->wsize);
 6020
 6021        // prepare kernel data (src0) from (K x Cout x Cin) to (Cin x K x Cout)
 6022        {
 6023            float * const wdata = (float *) params->wdata + 0;
 6024
 6025            for (int64_t i02 = 0; i02 < ne02; i02++) {
 6026                for (int64_t i01 = 0; i01 < ne01; i01++) {
 6027                    const float * const src = (float *)((char *) src0->data + i02*nb02 + i01*nb01);
 6028                    float * dst_data = wdata + i01*ne00*ne02;
 6029                    for (int64_t i00 = 0; i00 < ne00; i00++) {
 6030                        dst_data[i00*ne02 + i02] = src[i00];
 6031                    }
 6032                }
 6033            }
 6034        }
 6035
6036        // permute source data (src1) from (L x Cin) to (Cin x L)
 6037        {
 6038            float * const wdata = (float *) params->wdata + nk;
 6039            float * dst_data = wdata;
 6040
 6041            for (int64_t i11 = 0; i11 < ne11; i11++) {
 6042                const float * const src = (float *)((char *) src1->data + i11*nb11);
 6043                for (int64_t i10 = 0; i10 < ne10; i10++) {
 6044                    dst_data[i10*ne11 + i11] = src[i10];
 6045                }
 6046            }
 6047        }
 6048
 6049        // need to zero dst since we are accumulating into it
 6050        memset(dst->data, 0, ggml_nbytes(dst));
 6051    }
 6052    ggml_barrier(params->threadpool);
 6053
 6054    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
 6055
 6056    // total rows in dst
 6057    const int nr = ne1;
 6058
 6059    // rows per thread
 6060    const int dr = (nr + nth - 1)/nth;
 6061
 6062    // row range for this thread
 6063    const int ir0 = dr*ith;
 6064    const int ir1 = MIN(ir0 + dr, nr);
 6065
 6066    float * const wdata     = (float *) params->wdata + 0;
 6067    float * const wdata_src = wdata + nk;
 6068
 6069    for (int i1 = ir0; i1 < ir1; i1++) {
 6070        float * dst_data = (float *)((char *) dst->data + i1*nb1);
 6071        float * wdata_kernel = wdata + i1*ne02*ne00;
 6072        for (int i10 = 0; i10 < ne10; i10++) {
 6073            const int i1n = i10*ne11;
 6074            for (int i00 = 0; i00 < ne00; i00++) {
 6075                float v = 0;
 6076                ggml_vec_dot_f32(ne02, &v, 0,
 6077                        wdata_src + i1n, 0,
 6078                        wdata_kernel + i00*ne02, 0, 1);
 6079                dst_data[i10*s0 + i00] += v;
 6080            }
 6081        }
 6082    }
 6083}
 6084
 6085void ggml_compute_forward_conv_transpose_1d(
 6086        const ggml_compute_params * params,
 6087              ggml_tensor * dst) {
 6088
 6089    const ggml_tensor * src0 = dst->src[0];
 6090
 6091    switch (src0->type) {
 6092        case GGML_TYPE_F16:
 6093            {
 6094                ggml_compute_forward_conv_transpose_1d_f16_f32(params, dst);
 6095            } break;
 6096        case GGML_TYPE_F32:
 6097            {
 6098                ggml_compute_forward_conv_transpose_1d_f32(params, dst);
 6099            } break;
 6100        default:
 6101            {
 6102                GGML_ABORT("fatal error");
 6103            }
 6104    }
 6105}
 6106
 6107// ggml_compute_forward_im2col_f32
 6108// src0: kernel [OC, IC, KH, KW]
 6109// src1: image [N, IC, IH, IW]
 6110// dst:  result [N, OH, OW, IC*KH*KW]
 6111static void ggml_compute_forward_im2col_f32(
 6112        const ggml_compute_params * params,
 6113              ggml_tensor * dst) {
 6114
 6115    const ggml_tensor * src0 = dst->src[0];
 6116    const ggml_tensor * src1 = dst->src[1];
 6117
 6118    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6119    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 6120
 6121    GGML_TENSOR_BINARY_OP_LOCALS;
 6122
 6123    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 6124    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 6125    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
 6126    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
 6127    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
 6128    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
 6129    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 6130
 6131    const int ith = params->ith;
 6132    const int nth = params->nth;
 6133
 6134    const int64_t N  = is_2D ? ne13 : ne12;
 6135    const int64_t IC = is_2D ? ne12 : ne11;
 6136    const int64_t IH = is_2D ? ne11 : 1;
 6137    const int64_t IW = ne10;
 6138
 6139    const int64_t KH = is_2D ? ne01 : 1;
 6140    const int64_t KW = ne00;
 6141
 6142    const int64_t OH = is_2D ? ne2 : 1;
 6143    const int64_t OW = ne1;
 6144
 6145    int ofs0 = is_2D ? nb13 : nb12;
 6146    int ofs1 = is_2D ? nb12 : nb11;
 6147
 6148    GGML_ASSERT(nb10 == sizeof(float));
 6149
 6150    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 6151    {
 6152        float * const wdata = (float *) dst->data;
 6153
 6154        for (int64_t in = 0; in < N; in++) {
 6155            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
 6156                for (int64_t iow = 0; iow < OW; iow++) {
 6157                    for (int64_t iic = ith; iic < IC; iic += nth) {
 6158
 6159                        // micro kernel
 6160                        float * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
 6161                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
 6162
 6163                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
 6164                            for (int64_t ikw = 0; ikw < KW; ikw++) {
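                                // map the output column/row and kernel tap back to input
                                // coordinates, accounting for stride, dilation and padding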
 6165                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
 6166                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
 6167
 6168                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 6169                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
 6170                                } else {
 6171                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = (src_data[iih*IW + iiw]);
 6172                                }
 6173                            }
 6174                        }
 6175                    }
 6176                }
 6177            }
 6178        }
 6179    }
 6180}
 6181
 6182
 6183// ggml_compute_forward_im2col_f16
 6184// src0: kernel [OC, IC, KH, KW]
 6185// src1: image [N, IC, IH, IW]
 6186// dst:  result [N, OH, OW, IC*KH*KW]
 6187static void ggml_compute_forward_im2col_f16(
 6188        const ggml_compute_params * params,
 6189              ggml_tensor * dst) {
 6190
 6191    const ggml_tensor * src0 = dst->src[0];
 6192    const ggml_tensor * src1 = dst->src[1];
 6193
 6194    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 6195    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6196    GGML_ASSERT( dst->type == GGML_TYPE_F16);
 6197
 6198    GGML_TENSOR_BINARY_OP_LOCALS;
 6199
 6200    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 6201    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 6202    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
 6203    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
 6204    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
 6205    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
 6206    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 6207
 6208    const int ith = params->ith;
 6209    const int nth = params->nth;
 6210
 6211    const int64_t N  = is_2D ? ne13 : ne12;
 6212    const int64_t IC = is_2D ? ne12 : ne11;
 6213    const int64_t IH = is_2D ? ne11 : 1;
 6214    const int64_t IW = ne10;
 6215
 6216    const int64_t KH = is_2D ? ne01 : 1;
 6217    const int64_t KW = ne00;
 6218
 6219    const int64_t OH = is_2D ? ne2 : 1;
 6220    const int64_t OW = ne1;
 6221
 6222    int ofs0 = is_2D ? nb13 : nb12;
 6223    int ofs1 = is_2D ? nb12 : nb11;
 6224
 6225    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 6226    GGML_ASSERT(nb10 == sizeof(float));
 6227
 6228    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 6229    {
 6230        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 6231
 6232        for (int64_t in = 0; in < N; in++) {
 6233            for (int64_t ioh = 0; ioh < OH; ioh++) { // 1
 6234                for (int64_t iow = 0; iow < OW; iow++) {
 6235                    for (int64_t iic = ith; iic < IC; iic += nth) {
 6236
 6237                        // micro kernel
 6238                        ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
 6239                        const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW]
 6240
 6241                        for (int64_t ikh = 0; ikh < KH; ikh++) {  // 1
 6242                            for (int64_t ikw = 0; ikw < KW; ikw++) {
 6243                                const int64_t iiw = iow*s0 + ikw*d0 - p0;
 6244                                const int64_t iih = ioh*s1 + ikh*d1 - p1;
 6245
 6246                                if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 6247                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0;
 6248                                } else {
 6249                                    dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]);
 6250                                }
 6251                            }
 6252                        }
 6253                    }
 6254                }
 6255            }
 6256        }
 6257    }
 6258}
 6259
 6260void ggml_compute_forward_im2col(
 6261        const ggml_compute_params * params,
 6262              ggml_tensor * dst) {
 6263    switch (dst->type) {
 6264        case GGML_TYPE_F16:
 6265            {
 6266                ggml_compute_forward_im2col_f16(params, dst);
 6267            } break;
 6268        case GGML_TYPE_F32:
 6269            {
 6270                ggml_compute_forward_im2col_f32(params, dst);
 6271            } break;
 6272        default:
 6273            {
 6274                GGML_ABORT("fatal error");
 6275            }
 6276    }
 6277}
 6278
 6279// ggml_compute_forward_im2col_back_f32
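//
// Inverse gather of the forward im2col: for each input position (iih, iiw) the gradient
// sums over all kernel offsets (ikh, ikw) and output positions (ioh, iow) that read it in
// the forward pass, i.e. those with iiw = iow*s0 + ikw*d0 - p0.  Solving for
// iow = (iiw + p0 - ikw*d0)/s0 (and analogously for ioh) and skipping positions where the
// division is not exact or the result falls outside [0, OW) / [0, OH) enumerates exactly
// that set.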
 6280
 6281void ggml_compute_forward_im2col_back_f32(
 6282        const ggml_compute_params * params,
 6283              ggml_tensor * dst) {
 6284
 6285    const ggml_tensor * src0 = dst->src[0]; // gradients of forward pass output
 6286    const ggml_tensor * src1 = dst->src[1]; // convolution kernel
 6287
 6288    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 6289    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6290    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 6291
 6292    GGML_TENSOR_BINARY_OP_LOCALS;
 6293
 6294    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 6295    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 6296    const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
 6297    const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
 6298    const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
 6299    const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
 6300    const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
 6301
 6302    const int ith = params->ith;
 6303    const int nth = params->nth;
 6304
 6305    const int64_t N  = is_2D ? ne3 : ne2;
 6306    const int64_t IC = is_2D ? ne2 : ne1;
 6307    const int64_t IH = is_2D ? ne1 : 1;
 6308    const int64_t IW = ne0;
 6309
 6310    const int64_t KH = is_2D ? ne11 : 1;
 6311    const int64_t KW = ne10;
 6312
 6313    const int64_t OH = is_2D ? ne02 : 1;
 6314    const int64_t OW = ne01;
 6315
 6316    const int64_t ofs0 = is_2D ? nb3 : nb2;
 6317    const int64_t ofs1 = is_2D ? nb2 : nb1;
 6318
 6319    GGML_ASSERT(nb0  == sizeof(float));
 6320
 6321    // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 6322    {
 6323        float * const wdata = (float *) dst->data;
 6324
 6325        for (int64_t in = 0; in < N; in++) {
 6326            for (int64_t iic = ith; iic < IC; iic += nth) {
 6327                for (int64_t iih = 0; iih < IH; iih++) {
 6328                    for (int64_t iiw = 0; iiw < IW; iiw++) {
 6329
 6330                        // micro kernel
 6331                        float grad = 0.0f;
 6332                        for (int64_t ikh = 0; ikh < KH; ikh++) {
 6333                            for (int64_t ikw = 0; ikw < KW; ikw++) {
 6334                                // For s0 > 1 some values were skipped over in the forward pass.
 6335                                // These values have tmpw % s0 != 0 and need to be skipped in the backwards pass as well.
 6336                                const int64_t tmpw = (iiw + p0 - ikw*d0);
 6337                                if (tmpw % s0 != 0) {
 6338                                    continue;
 6339                                }
 6340                                const int64_t iow = tmpw / s0;
 6341
 6342                                // Equivalent logic as above except for s1.
 6343                                int64_t ioh;
 6344                                if (is_2D) {
 6345                                    const int64_t tmph = iih + p1 - ikh*d1;
 6346
 6347                                    if (tmph % s1 != 0) {
 6348                                        continue;
 6349                                    }
 6350
 6351                                    ioh = tmph / s1;
 6352                                } else {
 6353                                    ioh = 0;
 6354                                }
 6355
 6356                                if (iow < 0 || iow >= OW || ioh < 0 || ioh >= OH) {
 6357                                    continue;
 6358                                }
 6359
 6360                                const float * const grad_in = (const float *) src0->data
 6361                                    + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW]
 6362                                grad += grad_in[iic*(KH*KW) + ikh*KW + ikw];
 6363                            }
 6364                        }
 6365                        float * dst_data = (float *)((char *) wdata + (in*ofs0 + iic*ofs1)); // [IH, IW]
 6366                        dst_data[iih*IW + iiw] = grad;
 6367                    }
 6368                }
 6369            }
 6370        }
 6371    }
 6372}
 6373
 6374
 6375// ggml_compute_forward_im2col_3d_f16
 6376// src0: kernel [OC*IC, KD, KH, KW]
 6377// src1: image [N*IC, ID, IH, IW]
 6378// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
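//
// 3D variant of the im2col above: the window gather additionally runs over a depth
// dimension (iid = iod*s2 + ikd*d2 - p2).  Because ggml tensors have at most four dims,
// the batch and channel dims are folded together in both src1 ([N*IC, ID, IH, IW]) and
// dst ([N*OD, OH, OW, IC*KD*KH*KW]), so IC is passed explicitly through op_params[9].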
 6379static void ggml_compute_forward_im2col_3d_f16(
 6380        const ggml_compute_params * params,
 6381              ggml_tensor * dst) {
 6382
 6383    const ggml_tensor * src0 = dst->src[0];
 6384    const ggml_tensor * src1 = dst->src[1];
 6385
 6386    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 6387    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6388    GGML_ASSERT( dst->type == GGML_TYPE_F16);
 6389
 6390    GGML_TENSOR_BINARY_OP_LOCALS;
 6391
 6392    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 6393    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 6394    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
 6395    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
 6396    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
 6397    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
 6398    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
 6399    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
 6400    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
 6401    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
 6402
 6403
 6404    const int ith = params->ith;
 6405    const int nth = params->nth;
 6406
 6407    const int64_t N  = ne13 / IC;
 6408    const int64_t ID = ne12;
 6409    const int64_t IH = ne11;
 6410    const int64_t IW = ne10;
 6411
 6412    const int64_t OC = ne03 / IC;
 6413    GGML_UNUSED(OC);
 6414    const int64_t KD = ne02;
 6415    const int64_t KH = ne01;
 6416    const int64_t KW = ne00;
 6417
 6418    const int64_t OD = ne3 / N;
 6419    const int64_t OH = ne2;
 6420    const int64_t OW = ne1;
 6421    const int64_t OH_OW = OH*OW;
 6422    const int64_t KD_KH_KW = KD*KH*KW;
 6423    const int64_t KH_KW = KH*KW;
 6424    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
 6425
 6426    GGML_ASSERT(nb10 == sizeof(float));
 6427
 6428    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
 6429    {
 6430        ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
 6431
 6432        for (int64_t in = 0; in < N; in++) {
 6433            for (int64_t iod = 0; iod < OD; iod++) {
 6434                for (int64_t ioh = 0; ioh < OH; ioh++) {
 6435                    for (int64_t iow = 0; iow < OW; iow++) {
 6436                        for (int64_t iic = ith; iic < IC; iic += nth) {
 6437
 6438                            // micro kernel
 6439                            ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
 6440                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
 6441
 6442                            for (int64_t ikd = 0; ikd < KD; ikd++) {
 6443                                for (int64_t ikh = 0; ikh < KH; ikh++) {
 6444                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
 6445                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
 6446                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
 6447                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
 6448
 6449                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 6450                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
 6451                                        } else {
 6452                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
 6453                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
 6454                                        }
 6455                                    }
 6456                                }
 6457                            }
 6458                        }
 6459                    }
 6460                }
 6461            }
 6462        }
 6463    }
 6464}
 6465
 6466// ggml_compute_forward_im2col_3d_f32
 6467// src0: kernel [OC*IC, KD, KH, KW]
 6468// src1: image [N*IC, ID, IH, IW]
 6469// dst:  result [N*OD, OH, OW, IC * KD * KH * KW]
 6470static void ggml_compute_forward_im2col_3d_f32(
 6471        const ggml_compute_params * params,
 6472              ggml_tensor * dst) {
 6473
 6474    const ggml_tensor * src0 = dst->src[0];
 6475    const ggml_tensor * src1 = dst->src[1];
 6476
 6477    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6478    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 6479
 6480    GGML_TENSOR_BINARY_OP_LOCALS;
 6481
 6482    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
 6483    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
 6484    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
 6485    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
 6486    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
 6487    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
 6488    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
 6489    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
 6490    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
 6491    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
 6492
 6493
 6494    const int ith = params->ith;
 6495    const int nth = params->nth;
 6496
 6497    const int64_t N  = ne13 / IC;
 6498    const int64_t ID = ne12;
 6499    const int64_t IH = ne11;
 6500    const int64_t IW = ne10;
 6501
 6502    const int64_t OC = ne03 / IC;
 6503    GGML_UNUSED(OC);
 6504    const int64_t KD = ne02;
 6505    const int64_t KH = ne01;
 6506    const int64_t KW = ne00;
 6507
 6508    const int64_t OD = ne3 / N;
 6509    const int64_t OH = ne2;
 6510    const int64_t OW = ne1;
 6511
 6512    const int64_t OH_OW = OH*OW;
 6513    const int64_t KD_KH_KW = KD*KH*KW;
 6514    const int64_t KH_KW = KH*KW;
 6515    const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
 6516
 6517    GGML_ASSERT(nb10 == sizeof(float));
 6518
 6519    // im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
 6520    {
 6521        float * const wdata = (float *) dst->data;
 6522
 6523        for (int64_t in = 0; in < N; in++) {
 6524            for (int64_t iod = 0; iod < OD; iod++) {
 6525                for (int64_t ioh = 0; ioh < OH; ioh++) {
 6526                    for (int64_t iow = 0; iow < OW; iow++) {
 6527                        for (int64_t iic = ith; iic < IC; iic += nth) {
 6528
 6529                            // micro kernel
 6530                            float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
 6531                            const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
 6532
 6533                            for (int64_t ikd = 0; ikd < KD; ikd++) {
 6534                                for (int64_t ikh = 0; ikh < KH; ikh++) {
 6535                                    for (int64_t ikw = 0; ikw < KW; ikw++) {
 6536                                        const int64_t iiw = iow*s0 + ikw*d0 - p0;
 6537                                        const int64_t iih = ioh*s1 + ikh*d1 - p1;
 6538                                        const int64_t iid = iod*s2 + ikd*d2 - p2;
 6539
 6540                                        if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
 6541                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
 6542                                        } else {
 6543                                            const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
 6544                                            dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
 6545                                        }
 6546                                    }
 6547                                }
 6548                            }
 6549                        }
 6550                    }
 6551                }
 6552            }
 6553        }
 6554    }
 6555}
 6556
 6557
 6558void ggml_compute_forward_im2col_3d(
 6559        const ggml_compute_params * params,
 6560              ggml_tensor * dst) {
 6561    switch (dst->type) {
 6562        case GGML_TYPE_F16:
 6563            {
 6564                ggml_compute_forward_im2col_3d_f16(params, dst);
 6565            } break;
 6566        case GGML_TYPE_F32:
 6567            {
 6568                ggml_compute_forward_im2col_3d_f32(params, dst);
 6569            } break;
 6570        default:
 6571            {
 6572                GGML_ABORT("fatal error");
 6573            }
 6574    }
 6575}
 6576
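// ggml_call_mul_mat
//
// Wraps three raw buffers in stack-allocated ggml_tensor structs so that
// ggml_compute_forward_mul_mat can be reused as a plain GEMM by the conv_2d/conv_3d
// paths below: a holds m rows of k values (the im2col patches), b holds n rows of k
// values (the flattened kernel) and c receives m*n floats with
// c[i*n + j] = dot(row i of a, row j of b).  Every thread calls it and mul_mat splits
// the work internally by params->ith/nth.  As used by conv_2d below, roughly:
//
//   ggml_call_mul_mat(kernel_type, params, /*m=*/patch_n, /*n=*/c_out, /*k=*/knl_n,
//                     /*a=*/tmp, /*b=*/knl_data, /*c=*/gemm_output);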
 6577static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
 6578                              void * a, void * b, float * c) {
 6579    const ggml_type_traits * traits = ggml_get_type_traits(type);
 6580    struct ggml_tensor src1 = {};
 6581    src1.type  = type;
 6582    src1.ne[0] = k;
 6583    src1.ne[1] = m;
 6584    src1.ne[2] = 1;
 6585    src1.ne[3] = 1;
 6586    src1.nb[0] = traits->type_size;
 6587    src1.nb[1] = k * traits->type_size;
 6588    src1.nb[2] = src1.nb[1];
 6589    src1.nb[3] = src1.nb[2];
 6590    src1.data  = a;
 6591
 6592    struct ggml_tensor src0 = {};
 6593    src0.type  = type;
 6594    src0.ne[0] = k;
 6595    src0.ne[1] = n;
 6596    src0.ne[2] = 1;
 6597    src0.ne[3] = 1;
 6598    src0.nb[0] = traits->type_size;
 6599    src0.nb[1] = k * traits->type_size;
 6600    src0.nb[2] = src0.nb[1];
 6601    src0.nb[3] = src0.nb[2];
 6602    src0.data  = b;
 6603
 6604    struct ggml_tensor dst = {};
 6605    dst.ne[0] = n;
 6606    dst.ne[1] = m;
 6607    dst.ne[2] = 1;
 6608    dst.ne[3] = 1;
 6609    dst.nb[0] = sizeof(float);
 6610    dst.nb[1] = n * sizeof(float);
 6611    dst.nb[2] = dst.nb[1];
 6612    dst.nb[3] = dst.nb[2];
 6613    dst.data  = c;
 6614    dst.src[0] = &src0;
 6615    dst.src[1] = &src1;
 6616
 6617    ggml_compute_forward_mul_mat(params, &dst);
 6618}
 6619
 6620static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) {
 6621    return (coord + size) % size; // adding size keeps the result non-negative for coord in [-size, size), e.g. (-1 + 5) % 5 == 4
 6622}
 6623
 6624// ggml_compute_forward_conv_2d
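//
// Direct convolution as batched im2col + GEMM, sized to fit the scratch buffer:
//   1. the OW*OH*N output positions ("patches") are split into batches that fit params->wdata;
//   2. each thread writes the im2col rows for its share of a batch into the scratch buffer;
//   3. after a barrier, all threads run one shared GEMM (ggml_call_mul_mat) of the patch
//      matrix against the flattened kernel, producing [patch, OC] results;
//   4. after another barrier, each thread scatters its share of the result into dst's
//      [OW, OH, OC, N] layout.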
 6625
 6626
 6627static void ggml_compute_forward_conv_2d_impl(const ggml_compute_params * params,
 6628                                              const ggml_tensor *         kernel,  // [KW, KH, IC, OC]
 6629                                              const ggml_tensor *         src,     // [W, H, C, N]
 6630                                              ggml_tensor *               dst,     // [OW, OH, OC, N]
 6631                                              ggml_type                   kernel_type) {
 6632
 6633    GGML_ASSERT(ggml_is_contiguous(kernel));
 6634    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
 6635    GGML_ASSERT(kernel->type == kernel_type);
 6636
 6637    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
 6638
 6639    const int32_t stride_x   = dst->op_params[0];
 6640    const int32_t stride_y   = dst->op_params[1];
 6641    const int32_t pad_x      = dst->op_params[2];
 6642    const int32_t pad_y      = dst->op_params[3];
 6643    const int32_t dilation_x = dst->op_params[4];
 6644    const int32_t dilation_y = dst->op_params[5];
 6645
 6646    const int64_t c_in  = src->ne[2];
 6647    const int64_t c_out = kernel->ne[3];
 6648    GGML_ASSERT(c_in == kernel->ne[2]);
 6649
 6650    const int64_t src_w = src->ne[0];
 6651    const int64_t src_h = src->ne[1];
 6652    const int64_t knl_w = kernel->ne[0];
 6653    const int64_t knl_h = kernel->ne[1];
 6654    const int64_t dst_w = dst->ne[0];
 6655    const int64_t dst_h = dst->ne[1];
 6656
 6657    const float * src_data = (float *) src->data;
 6658    void  * knl_data       = kernel->data;
 6659    float * dst_data       = (float *) dst->data;
 6660
 6661    const int64_t knl_n           = knl_w * knl_h * c_in;
 6662    const int64_t patch_total     = dst->ne[3] * dst_w * dst_h;
 6663
 6664    const int64_t space_per_patch   = knl_n * traits->type_size + c_out * sizeof(float);
 6665    const int64_t batch_size        = params->wsize / space_per_patch;
 6666    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
 6667    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 6668
 6669    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 6670
 6671    void * tmp = params->wdata;
 6672
 6673    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 6674
 6675        const int64_t patch_start_batch = batch_i * patches_per_batch;
 6676        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch,
 6677                                              patch_total);
 6678        const int64_t patch_n           = patch_end_batch - patch_start_batch;
 6679
 6680        const int64_t patch_per_thread  = (patch_n + params->nth - 1) / params->nth;
 6681        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
 6682        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
 6683
 6684        //im2col for a patch
 6685        for (int64_t p = patch_start; p < patch_end; ++p) {
 6686            const int64_t  batch_n     =  p / (dst_w * dst_h); // image index in the batch (shadows the outer batch count)
 6687            const int64_t  src_x       = (p / dst_w) % dst_h;  // despite the name, this is the output row    (dst y)
 6688            const int64_t  src_y       =  p % dst_w;           // despite the name, this is the output column (dst x)
 6689
 6690            const float * src_base = (const float *)((const char *)src_data + batch_n * src->nb[3]);
 6691            char *        dst_row  = (char *) tmp + (p % patches_per_batch) * knl_n * traits->type_size;
 6692
 6693            for (int64_t ic = 0; ic < c_in; ++ic) {
 6694                for (int64_t ky = 0; ky < knl_h; ++ky) {
 6695                    for (int64_t kx = 0; kx < knl_w; ++kx) {
 6696                        const int64_t sy = src_x * stride_y + ky * dilation_y - pad_y;
 6697                        const int64_t sx = src_y * stride_x + kx * dilation_x - pad_x;
 6698
 6699                        int64_t dst_idx = ic * (knl_h * knl_w) + ky * knl_w + kx;
 6700
 6701                        float src_val;
 6702                        if (sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
 6703                            src_val = 0.0f;
 6704                        } else {
 6705                            const float * src_ptr = (const float *)((const char *)src_base + sx * src->nb[0] + sy * src->nb[1] + ic * src->nb[2]);
 6706                            src_val               = *src_ptr;
 6707                        }
 6708
 6709                        char * element_ptr = dst_row + dst_idx * traits->type_size;
 6710                        if (kernel_type == GGML_TYPE_F32) {
 6711                            *(float *) element_ptr = src_val;
 6712                        } else if (kernel_type == GGML_TYPE_F16) {
 6713                            *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
 6714                        }
 6715                    }
 6716                }
 6717            }
 6718        }   // patches handled by this thread
 6719
 6720        ggml_barrier(params->threadpool);
 6721
 6722        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n * traits->type_size);
 6723
 6724        GGML_ASSERT((char *) (gemm_output + patch_n * c_out) <= (char *) tmp + params->wsize); // wsize is in bytes
 6725
 6726        // GEMM: patches[patch_n, knl_n] × kernel[knl_n, c_out] = output[patch_n, c_out]
 6727        ggml_call_mul_mat(kernel_type, params, patch_n, c_out, knl_n, tmp, knl_data, gemm_output);
 6728
 6729        ggml_barrier(params->threadpool);
 6730
 6731
 6732        // scatter the GEMM result [patch_n, OC] back into dst, which is laid out as [OW, OH, OC, N]
 6733        const int64_t permute_per_thread = (patch_n + params->nth - 1) / params->nth;
 6734        const int64_t permute_start = params->ith * permute_per_thread;
 6735        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n);
 6736
 6737        for (int64_t i = permute_start; i < permute_end; ++i) {
 6738            const int64_t p       = patch_start_batch + i;
 6739            const int64_t batch_n = p / (dst_w * dst_h);
 6740            const int64_t dst_y   = (p / dst_w) % dst_h;
 6741            const int64_t dst_x   = p % dst_w;
 6742
 6743            for (int64_t oc = 0; oc < c_out; ++oc) {
 6744                const float value = gemm_output[i * c_out + oc];
 6745                float * dst_ptr = (float *)((char *)dst_data + dst_x * dst->nb[0] + dst_y * dst->nb[1] + oc * dst->nb[2] + batch_n * dst->nb[3]);
 6746                *dst_ptr = value;
 6747            }
 6748        }
 6749    }
 6750}
 6751
 6752void ggml_compute_forward_conv_2d(
 6753        const ggml_compute_params * params,
 6754        ggml_tensor * dst) {
 6755
 6756    const ggml_tensor * src0 = dst->src[0];
 6757    const ggml_tensor * src1 = dst->src[1];
 6758
 6759    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
 6760}
 6761
 6762// ggml_compute_forward_conv_3d
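//
// Same batched im2col + GEMM scheme as conv_2d above, extended with a depth dimension.
// Because ggml tensors have at most four dims, the channel and batch dims are folded
// together (src is [W, H, D, C*N], dst is [OW, OH, OD, OC*N]) and C, N and OC are
// passed explicitly through op_params[9..11].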
 6763
 6764static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
 6765                                              const ggml_tensor *         kernel,
 6766                                              const ggml_tensor *         src,
 6767                                              ggml_tensor *               dst,
 6768                                              ggml_type                   kernel_type) {
 6769
 6770    GGML_ASSERT(ggml_is_contiguous(kernel));
 6771    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
 6772    GGML_ASSERT(kernel->type == kernel_type);
 6773
 6774    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
 6775
 6776    const int32_t s0 = dst->op_params[0];
 6777    const int32_t s1 = dst->op_params[1];
 6778    const int32_t s2 = dst->op_params[2];
 6779    const int32_t p0 = dst->op_params[3];
 6780    const int32_t p1 = dst->op_params[4];
 6781    const int32_t p2 = dst->op_params[5];
 6782    const int32_t d0 = dst->op_params[6];
 6783    const int32_t d1 = dst->op_params[7];
 6784    const int32_t d2 = dst->op_params[8];
 6785    const int32_t c  = dst->op_params[9];
 6786    const int32_t n  = dst->op_params[10];
 6787    const int32_t oc = dst->op_params[11];
 6788
 6789    const int64_t src_w = src->ne[0];
 6790    const int64_t src_h = src->ne[1];
 6791    const int64_t src_d = src->ne[2];
 6792    const int64_t knl_w = kernel->ne[0];
 6793    const int64_t knl_h = kernel->ne[1];
 6794    const int64_t knl_d = kernel->ne[2];
 6795    const int64_t dst_w = dst->ne[0];
 6796    const int64_t dst_h = dst->ne[1];
 6797    const int64_t dst_d = dst->ne[2];
 6798
 6799    const float * src_data = (float *) src->data;
 6800    void  * knl_data       = kernel->data;
 6801    float * dst_data       = (float *) dst->data;
 6802
 6803    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
 6804    const int64_t knl_n_total       = knl_n_per_channel * c;
 6805    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
 6806
 6807    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
 6808    const int64_t batch_size        = params->wsize / space_per_patch;
 6809    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
 6810    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;
 6811
 6812    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);
 6813
 6814    void * tmp = params->wdata;
 6815
 6816    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
 6817        const int64_t patch_start_batch = batch_i * patches_per_batch;
 6818        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
 6819        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;
 6820
 6821        const int64_t patch_per_thread  = (patch_n_in_batch + params->nth - 1) / params->nth;
 6822        const int64_t patch_start       = patch_start_batch + params->ith * patch_per_thread;
 6823        const int64_t patch_end         = std::min(patch_start + patch_per_thread, patch_end_batch);
 6824
 6825        for (int64_t p = patch_start; p < patch_end; ++p) {
 6826            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
 6827            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
 6828            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
 6829            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
 6830            const int64_t dst_y      = p_in_depth / dst_w;
 6831            const int64_t dst_x      = p_in_depth % dst_w;
 6832
 6833            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
 6834
 6835            for (int64_t ic = 0; ic < c; ++ic) {
 6836                for (int64_t kz = 0; kz < knl_d; ++kz) {
 6837                    for (int64_t ky = 0; ky < knl_h; ++ky) {
 6838                        for (int64_t kx = 0; kx < knl_w; ++kx) {
 6839                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
 6840                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
 6841                            const int64_t sx = dst_x * s0 + kx * d0 - p0;
 6842
 6843                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;
 6844
 6845                            float src_val;
 6846                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
 6847                                src_val = 0.0f;
 6848                            } else {
 6849                                const int64_t cn_idx = batch_idx * c + ic;
 6850                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
 6851                                src_val = *src_ptr;
 6852                            }
 6853
 6854                            char * element_ptr = dst_row + dst_idx * traits->type_size;
 6855                            if (kernel_type == GGML_TYPE_F32) {
 6856                                *(float *)element_ptr = src_val;
 6857                            } else if (kernel_type == GGML_TYPE_F16) {
 6858                                *(ggml_fp16_t *)element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
 6859                            }
 6860                        }
 6861                    }
 6862                }
 6863            }
 6864        }
 6865
 6866        ggml_barrier(params->threadpool);
 6867
 6868        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
 6869        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);
 6870
 6871        ggml_barrier(params->threadpool);
 6872
 6873        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
 6874        const int64_t permute_start = params->ith * permute_per_thread;
 6875        const int64_t permute_end = std::min(permute_start + permute_per_thread, patch_n_in_batch);
 6876
 6877        for (int64_t i = permute_start; i < permute_end; ++i) {
 6878            const int64_t p = patch_start_batch + i;
 6879            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
 6880            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
 6881            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
 6882            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
 6883            const int64_t dst_y      = p_in_depth / dst_w;
 6884            const int64_t dst_x      = p_in_depth % dst_w;
 6885
 6886            for (int64_t ioc = 0; ioc < oc; ++ioc) {
 6887                const float value = gemm_output[i * oc + ioc];
 6888                const int64_t ocn_idx = batch_idx * oc + ioc;
 6889                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
 6890                *dst_ptr = value;
 6891            }
 6892        }
 6893    }
 6894}
 6895
 6896void ggml_compute_forward_conv_3d(
 6897        const ggml_compute_params * params,
 6898        ggml_tensor * dst) {
 6899    const ggml_tensor * src0 = dst->src[0];
 6900    const ggml_tensor * src1 = dst->src[1];
 6901    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
 6902}
 6903
 6904// ggml_compute_forward_conv_transpose_2d
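//
// Thread 0 first rearranges both operands into params->wdata so that Cin becomes the
// innermost dimension: the f16 kernel from (Kw x Kh x Cout x Cin) and the f32 source
// from (Sw x Sh x Cin), the latter converted to f16.  After a barrier the Cout output
// channels are split across threads; every output element then accumulates dot products
// over Cin (ggml_vec_dot_f16) of the permuted source and kernel, added into dst at the
// positions determined by the stride.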
 6905
 6906void ggml_compute_forward_conv_transpose_2d(
 6907        const ggml_compute_params * params,
 6908              ggml_tensor * dst) {
 6909
 6910    const ggml_tensor * src0 = dst->src[0];
 6911    const ggml_tensor * src1 = dst->src[1];
 6912
 6913    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 6914    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 6915    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 6916
 6917    GGML_TENSOR_BINARY_OP_LOCALS
 6918
 6919    const int ith = params->ith;
 6920    const int nth = params->nth;
 6921
 6922    const int nk = ne00*ne01*ne02*ne03;
 6923
 6924    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
 6925    GGML_ASSERT(nb10 == sizeof(float));
 6926
 6927    if (ith == 0) {
 6928        memset(params->wdata, 0, params->wsize);
 6929
 6930        // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout)
 6931        {
 6932            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 6933
 6934            for (int64_t i03 = 0; i03 < ne03; i03++) {
 6935                for (int64_t i02 = 0; i02 < ne02; i02++) {
 6936                    const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02);
 6937                    ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03;
 6938                    for (int64_t i01 = 0; i01 < ne01; i01++) {
 6939                        for (int64_t i00 = 0; i00 < ne00; i00++) {
 6940                            dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00];
 6941                        }
 6942                    }
 6943                }
 6944            }
 6945        }
 6946
 6947        // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh)
 6948        {
 6949            ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk;
 6950            for (int i12 = 0; i12 < ne12; i12++) {
 6951                for (int i11 = 0; i11 < ne11; i11++) {
 6952                    const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11);
 6953                    ggml_fp16_t * dst_data = wdata + i11*ne10*ne12;
 6954                    for (int i10 = 0; i10 < ne10; i10++) {
 6955                        dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]);
 6956                    }
 6957                }
 6958            }
 6959        }
 6960
 6961        memset(dst->data, 0, ggml_nbytes(dst));
 6962    }
 6963    ggml_barrier(params->threadpool);
 6964
 6965    const int32_t stride = ggml_get_op_params_i32(dst, 0);
 6966
 6967    // total patches in dst
 6968    const int np = ne2;
 6969
 6970    // patches per thread
 6971    const int dp = (np + nth - 1)/nth;
 6972
 6973    // patch range for this thread
 6974    const int ip0 = dp*ith;
 6975    const int ip1 = MIN(ip0 + dp, np);
 6976
 6977    ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
 6978    ggml_fp16_t * const wdata_src = wdata + nk;
 6979
 6980    for (int i2 = ip0; i2 < ip1; i2++) { // Cout
 6981        float * dst_data = (float *)((char *) dst->data + i2*nb2);
 6982        ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03;
 6983        for (int i11 = 0; i11 < ne11; i11++) {
 6984            for (int i10 = 0; i10 < ne10; i10++) {
 6985                const int i1n = i11*ne10*ne12 + i10*ne12;
 6986                for (int i01 = 0; i01 < ne01; i01++) {
 6987                    for (int i00 = 0; i00 < ne00; i00++) {
 6988                        float v = 0;
 6989                        ggml_vec_dot_f16(ne03, &v, 0,
 6990                                wdata_src + i1n, 0,
 6991                                wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
 6992                        dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
 6993                    }
 6994                }
 6995            }
 6996        }
 6997    }
 6998}
 6999
 7000// ggml_compute_forward_conv_2d_dw
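//
// Depthwise convolution: every input channel is convolved with its own single-channel
// kernel.  Two layouts are handled below: whcn (the default ggml contiguous layout, where
// each thread processes whole channel planes) and cwhn (channels innermost in memory,
// which lets the inner loop run SIMD across channels).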
 7001
 7002struct ggml_conv_2d_dw_params {
 7003    int64_t channels;
 7004    int64_t batch;
 7005    int64_t src_w;
 7006    int64_t src_h;
 7007    int64_t dst_w;
 7008    int64_t dst_h;
 7009    int64_t knl_w;
 7010    int64_t knl_h;
 7011    int stride_x;
 7012    int stride_y;
 7013    int pad_x;
 7014    int pad_y;
 7015    int dilation_x;
 7016    int dilation_y;
 7017};
 7018
 7019static void ggml_compute_forward_conv_2d_dw_cwhn(
 7020        const ggml_compute_params * params,
 7021        const ggml_tensor * src,
 7022        const ggml_tensor * kernel,
 7023        ggml_tensor * dst,
 7024        const ggml_conv_2d_dw_params & p) {
 7025
 7026    const int64_t c = p.channels;
 7027    const float * knl_data = (const float *)kernel->data;
 7028
 7029    const int64_t rows_total = p.dst_h * p.batch;
 7030    const int64_t rows_per_thread = (rows_total + params->nth - 1) / params->nth;
 7031    const int64_t row_start = params->ith * rows_per_thread;
 7032    const int64_t row_end = MIN(row_start + rows_per_thread, rows_total);
 7033
 7034#ifdef GGML_SIMD
 7035    #if defined(__ARM_FEATURE_SVE)
 7036        const int64_t pkg_size = svcntw();
 7037    #else
 7038        const int64_t pkg_size = GGML_F32_EPR;
 7039    #endif
 7040    const int64_t pkg_count = c / pkg_size;
 7041    const int64_t c_pkg_end = pkg_count * pkg_size;
 7042#else
 7043    const int64_t c_pkg_end = 0;
 7044#endif
 7045
 7046    for (int64_t row = row_start; row < row_end; ++row) {
 7047        const int64_t dst_y = row % p.dst_h;
 7048        const float * src_data = (const float *)src->data + (row / p.dst_h) * p.src_w * p.src_h * c;
 7049        for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
 7050            float * dst_data = (float *)dst->data + (row * p.dst_w + dst_x) * c;
 7051            const int64_t src_y_base = dst_y * p.stride_y - p.pad_y;
 7052            const int64_t src_x_base = dst_x * p.stride_x - p.pad_x;
 7053
 7054#ifdef GGML_SIMD
 7055            // Vectorized loop
 7056            for (int64_t c_i = 0; c_i < c_pkg_end; c_i += pkg_size) {
 7057                GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
 7058                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
 7059                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
 7060                    if (src_y < 0 || src_y >= p.src_h) {
 7061                        continue;
 7062                    }
 7063                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
 7064                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
 7065                        if (src_x < 0 || src_x >= p.src_w) {
 7066                            continue;
 7067                        }
 7068                        GGML_F32_VEC k = GGML_F32_VEC_LOAD(knl_data + (knl_y * p.knl_w + knl_x) * c + c_i);
 7069                        GGML_F32_VEC s = GGML_F32_VEC_LOAD(src_data + (src_y * p.src_w + src_x) * c + c_i);
 7070                        sum = GGML_F32_VEC_FMA(sum, k, s);
 7071                    }
 7072                }
 7073                GGML_F32_VEC_STORE(dst_data + c_i, sum);
 7074            }
 7075#endif
 7076            // Scalar loop
 7077            for (int64_t c_i = c_pkg_end; c_i < c; ++c_i) {
 7078                float sum = 0.0f;
 7079                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
 7080                    const int64_t src_y = src_y_base + knl_y * p.dilation_y;
 7081                    if (src_y < 0 || src_y >= p.src_h) {
 7082                        continue;
 7083                    }
 7084                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
 7085                        const int64_t src_x = src_x_base + knl_x * p.dilation_x;
 7086                        if (src_x < 0 || src_x >= p.src_w) {
 7087                            continue;
 7088                        }
 7089                        sum += knl_data[(knl_y * p.knl_w + knl_x) * c + c_i]
 7090                             * src_data[(src_y * p.src_w + src_x) * c + c_i];
 7091                    }
 7092                }
 7093                dst_data[c_i] = sum;
 7094            }
 7095        }
 7096    }
 7097}
 7098
 7099static void ggml_compute_forward_conv_2d_dw_whcn(
 7100        const ggml_compute_params * params,
 7101        const ggml_tensor * src,
 7102        const ggml_tensor * kernel,
 7103        ggml_tensor * dst,
 7104        const ggml_conv_2d_dw_params & p) {
 7105
 7106    const int64_t n = p.channels * p.batch;
 7107    const int64_t per_thread = (n + params->nth - 1) / params->nth;
 7108    const int64_t start = params->ith * per_thread;
 7109    const int64_t end = MIN(start + per_thread, n);
 7110
 7111    for (int64_t i = start; i < end; ++i) {
 7112        const float * knl_data = (const float *)kernel->data + (i % p.channels) * p.knl_w * p.knl_h;
 7113        const float * src_data = (const float *)src->data + i * p.src_w * p.src_h;
 7114        float * dst_data = (float *)dst->data + i * p.dst_w * p.dst_h;
 7115
 7116        for (int64_t dst_y = 0; dst_y < p.dst_h; ++dst_y) {
 7117            for (int64_t dst_x = 0; dst_x < p.dst_w; ++dst_x) {
 7118
 7119                float sum = 0.0f;
 7120                for (int64_t knl_y = 0; knl_y < p.knl_h; ++knl_y) {
 7121                    const int64_t src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y;
 7122                    if (src_y < 0 || src_y >= p.src_h) {
 7123                        continue;
 7124                    }
 7125                    for (int64_t knl_x = 0; knl_x < p.knl_w; ++knl_x) {
 7126                        const int64_t src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x;
 7127                        if (src_x < 0 || src_x >= p.src_w) {
 7128                            continue;
 7129                        }
 7130                        sum += knl_data[knl_y * p.knl_w + knl_x]
 7131                             * src_data[src_y * p.src_w + src_x];
 7132                    }
 7133                }
 7134                dst_data[dst_y * p.dst_w + dst_x] = sum;
 7135            }
 7136        }
 7137    }
 7138}
 7139
 7140void ggml_compute_forward_conv_2d_dw(
 7141        const ggml_compute_params * params,
 7142        ggml_tensor * dst) {
 7143
 7144    const ggml_tensor * kernel = dst->src[0];
 7145    const ggml_tensor * src = dst->src[1];
 7146    ggml_conv_2d_dw_params p;
 7147    p.channels = src->ne[2];
 7148    p.batch = src->ne[3];
 7149    p.src_w = src->ne[0];
 7150    p.src_h = src->ne[1];
 7151    p.dst_w = dst->ne[0];
 7152    p.dst_h = dst->ne[1];
 7153    p.knl_w = kernel->ne[0];
 7154    p.knl_h = kernel->ne[1];
 7155    p.stride_x = dst->op_params[0];
 7156    p.stride_y = dst->op_params[1];
 7157    p.pad_x = dst->op_params[2];
 7158    p.pad_y = dst->op_params[3];
 7159    p.dilation_x = dst->op_params[4];
 7160    p.dilation_y = dst->op_params[5];
 7161
 7162    GGML_ASSERT(kernel->ne[3] == p.channels);
 7163    GGML_ASSERT(dst->ne[3] == p.batch);
 7164
 7165    if (ggml_is_contiguous(src)) {
 7166        ggml_compute_forward_conv_2d_dw_whcn(params, src, kernel, dst, p);
 7167    } else if (ggml_is_contiguous_channels(src)) {
 7168        // kernel should also have channels most contiguous in memory
 7169        GGML_ASSERT(kernel->nb[0] >= kernel->nb[2] && kernel->nb[1] >= kernel->nb[0]);
 7170        ggml_compute_forward_conv_2d_dw_cwhn(params, src, kernel, dst, p);
 7171    } else {
 7172        GGML_ABORT("non-contiguous memory layout not supported");
 7173    }
 7174}
 7175
 7176// ggml_compute_forward_pool_1d_ksp
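// Single-threaded 1D pooling over the innermost dimension of every row; the average
// variant divides by the number of in-bounds elements, so padded positions do not
// contribute.  E.g. k = 2, s = 2, p = 0 over a row [1, 2, 3, 4] yields [1.5, 3.5] for
// AVG and [2, 4] for MAX.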
 7177static void ggml_compute_forward_pool_1d_ksp(
 7178        const ggml_compute_params * params,
 7179        const ggml_op_pool op,
 7180        const int k,
 7181        const int s,
 7182        const int p,
 7183        ggml_tensor * dst) {
 7184
 7185    const ggml_tensor * src = dst->src[0];
 7186
 7187    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
 7188
 7189    if (params->ith != 0) {
 7190        return;
 7191    }
 7192
 7193    const int64_t IW = src->ne[0];
 7194    const int64_t OW = dst->ne[0];
 7195
 7196    const int64_t nr = ggml_nrows(src);
 7197
 7198    for (int64_t ir = 0; ir < nr; ++ir) {
 7199        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
 7200        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
 7201
 7202        for (int64_t ow = 0; ow < OW; ++ow) {
 7203            float res = 0;
 7204            switch (op) {
 7205                case GGML_OP_POOL_AVG: res = 0.0f;     break;
 7206                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
 7207                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
 7208            }
 7209
 7210            int count = 0;
 7211            const int base = (int) ow * s - p;
 7212
 7213            for (int ki = 0; ki < k; ++ki) {
 7214                const int j = base + ki;
 7215                if (j < 0 || j >= (int) IW) {
 7216                    continue;
 7217                }
 7218
 7219                float v;
 7220                if (src->type == GGML_TYPE_F32) {
 7221                    v = ((const float *) srow_bytes)[j];
 7222                } else {
 7223                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
 7224                }
 7225
 7226                switch (op) {
 7227                    case GGML_OP_POOL_AVG: res += v;                break;
 7228                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
 7229                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
 7230                }
 7231
 7232                ++count;
 7233            }
 7234
 7235            switch (op) {
 7236                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
 7237                case GGML_OP_POOL_MAX:                                           break;
 7238                case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
 7239            }
 7240
 7241            drow[ow] = res;
 7242        }
 7243    }
 7244}
 7245
 7246// ggml_compute_forward_pool_1d
 7247
 7248void ggml_compute_forward_pool_1d(
 7249        const ggml_compute_params * params,
 7250              ggml_tensor * dst) {
 7251
 7252    const int32_t * opts = (const int32_t *)dst->op_params;
 7253    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
 7254    const int k0 = opts[1];
 7255    const int s0 = opts[2];
 7256    const int p0 = opts[3];
 7257
 7258    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 7259}
 7260
 7261// ggml_compute_forward_pool_2d
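//
// 2D pooling over each [ne0, ne1] plane, again on a single thread.  Note that AVG here
// divides by the full kernel area k0*k1, so out-of-range (padded) positions effectively
// count as zeros, unlike the in-bounds count used by pool_1d above.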
 7262
 7263void ggml_compute_forward_pool_2d(
 7264        const ggml_compute_params * params,
 7265        ggml_tensor * dst) {
 7266
 7267    const ggml_tensor * src = dst->src[0];
 7268
 7269    assert(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16);
 7270
 7271    if (params->ith != 0) {
 7272        return;
 7273    }
 7274
 7275    const int32_t * opts = (const int32_t *)dst->op_params;
 7276
 7277    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
 7278    const int k0 = opts[1];
 7279    const int k1 = opts[2];
 7280    const int s0 = opts[3];
 7281    const int s1 = opts[4];
 7282    const int p0 = opts[5];
 7283    const int p1 = opts[6];
 7284    const char * cdata = (const char*)src->data;
 7285    const char * const data_end = cdata + ggml_nbytes(src);
 7286
 7287    const int64_t px = dst->ne[0];
 7288    const int64_t py = dst->ne[1];
 7289    const int64_t pa = px * py;
 7290
 7291    float * dplane = (float *)dst->data;
 7292
 7293    const int ka = k0 * k1;
 7294    const int offset0 = -p0;
 7295    const int offset1 = -p1;
 7296
 7297    while (cdata < data_end) {
 7298        for (int oy = 0; oy < py; ++oy) {
 7299            float * const drow = dplane + oy * px;
 7300            float * const out  = drow;
 7301
 7302            for (int ox = 0; ox < px; ++ox) {
 7303                float res = 0;
 7304                switch (op) {
 7305                    case GGML_OP_POOL_AVG: res = 0;        break;
 7306                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
 7307                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
 7308                }
 7309
 7310                const int ix = offset0 + ox * s0;
 7311                const int iy = offset1 + oy * s1;
 7312
 7313                for (int ky = 0; ky < k1; ++ky) {
 7314                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
 7315                        continue;
 7316                    }
 7317
 7318                    const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
 7319                    for (int kx = 0; kx < k0; ++kx) {
 7320                        int j = ix + kx;
 7321                        if (j < 0 || j >= src->ne[0]) {
 7322                            continue;
 7323                        }
 7324
 7325                        const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
 7326                        switch (op) {
 7327                            case GGML_OP_POOL_AVG: res += srow_j;                break;
 7328                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
 7329                            case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
 7330                        }
 7331                    }
 7332                }
 7333                switch (op) {
 7334                    case GGML_OP_POOL_AVG:           res /= ka; break;
 7335                    case GGML_OP_POOL_MAX:                      break;
 7336                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
 7337                }
 7338
 7339                out[ox] = res;
 7340            }
 7341        }
 7342
 7343        cdata  += src->nb[2];
 7344        dplane += pa;
 7345    }
 7346}
 7347
 7348// ggml_compute_forward_pool_2d_back
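//
// src holds the gradients w.r.t. the pooled output and dstf the tensor that was pooled in
// the forward pass; dst receives the gradients w.r.t. that input.  For MAX the window is
// rescanned in dstf to find the arg-max and the whole output gradient is added there;
// for AVG the gradient is spread uniformly as grad/(k0*k1) over the window.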
 7349
 7350void ggml_compute_forward_pool_2d_back(
 7351        const ggml_compute_params * params,
 7352        ggml_tensor * dst) {
 7353
 7354    const ggml_tensor * src  = dst->src[0];
 7355    const ggml_tensor * dstf = dst->src[1]; // forward tensor of dst
 7356
 7357    assert(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
 7358
 7359    if (params->ith != 0) {
 7360        return;
 7361    }
 7362
 7363    const int32_t * opts = (const int32_t *)dst->op_params;
 7364    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
 7365    const int k0 = opts[1];
 7366    const int k1 = opts[2];
 7367    const int s0 = opts[3];
 7368    const int s1 = opts[4];
 7369    const int p0 = opts[5];
 7370    const int p1 = opts[6];
 7371
 7372    char       * cdata  = (char       *) dst->data;
 7373    const char * cdataf = (const char *) dstf->data;
 7374    const char * const data_end = cdata + ggml_nbytes(dst);
 7375
 7376    GGML_ASSERT(params->ith == 0);
 7377    memset(cdata, 0, ggml_nbytes(dst));
 7378
 7379    const int64_t px = src->ne[0];
 7380    const int64_t py = src->ne[1];
 7381    const int64_t pa = px * py;
 7382
 7383    const float * splane = (const float *) src->data;
 7384
 7385    const int ka = k0 * k1;
 7386    const int offset0 = -p0;
 7387    const int offset1 = -p1;
 7388
 7389    while (cdata < data_end) {
 7390        for (int oy = 0; oy < py; ++oy) {
 7391            const float * const srow = splane + oy * px;
 7392            for (int ox = 0; ox < px; ++ox) {
 7393                const float grad0 = srow[ox];
 7394
 7395                const int ix = offset0 + ox * s0;
 7396                const int iy = offset1 + oy * s1;
 7397
 7398                if (op == GGML_OP_POOL_MAX) {
 7399                    float maxval = -FLT_MAX;
 7400                    int kxmax = -1;
 7401                    int kymax = -1;
 7402
 7403                    for (int ky = 0; ky < k1; ++ky) {
 7404                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
 7405                            continue;
 7406                        }
 7407                        const void * drowf = (const void *)(cdataf + dst->nb[1] * (iy + ky));
 7408                        for (int kx = 0; kx < k0; ++kx) {
 7409                            int j = ix + kx;
 7410                            if (j < 0 || j >= dst->ne[0]) {
 7411                                continue;
 7412                            }
 7413
 7414                            const float val = dst->type == GGML_TYPE_F32 ?
 7415                                ((const float *) drowf)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drowf)[j]);
 7416                            if (val <= maxval) {
 7417                                continue;
 7418                            }
 7419
 7420                            maxval = val;
 7421                            kxmax = kx;
 7422                            kymax = ky;
 7423                        }
 7424                    }
 7425
 7426                    if (kxmax == -1 || kymax == -1) {
 7427                        continue;
 7428                    }
 7429
 7430                    void * drow = (void *)(cdata + dst->nb[1] * (iy + kymax));
 7431                    const int j = ix + kxmax;
 7432                    if (dst->type == GGML_TYPE_F32) {
 7433                        ((float *) drow)[j] += grad0;
 7434                    } else {
 7435                        ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad0 + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j]));
 7436                    }
 7437                } else if (op == GGML_OP_POOL_AVG) {
 7438                    const float grad = grad0 / ka;
 7439
 7440                    for (int ky = 0; ky < k1; ++ky) {
 7441                        if (iy + ky < 0 || iy + ky >= dst->ne[1]) {
 7442                            continue;
 7443                        }
 7444                        void * drow = (void *)(cdata + dst->nb[1] * (iy + ky));
 7445                        for (int kx = 0; kx < k0; ++kx) {
 7446                            int j = ix + kx;
 7447                            if (j < 0 || j >= dst->ne[0]) {
 7448                                continue;
 7449                            }
 7450
 7451                            if (dst->type == GGML_TYPE_F32) {
 7452                                ((float *) drow)[j] += grad;
 7453                            } else {
 7454                                ((ggml_fp16_t *) drow)[j] = GGML_CPU_FP32_TO_FP16(grad + GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) drow)[j])); // accumulate in f32; += on raw fp16 bits would be wrong
 7455                            }
 7456                        }
 7457                    }
 7458                } else {
 7459                    GGML_ASSERT(false);
 7460                }
 7461            }
 7462        }
 7463
 7464        cdata  += dst->nb[2];
 7465        cdataf += dst->nb[2];
 7466        splane += pa;
 7467    }
 7468}
 7469
 7470// ggml_compute_forward_upscale
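//
// Resamples src0 to the size of dst; nearest, bilinear (optionally antialiased) and
// bicubic modes are handled below.  The scale factors are derived from the two shapes,
// and GGML_SCALE_FLAG_ALIGN_CORNERS switches to the align-corners convention
// (pixel_offset = 0, e.g. sf0 = (ne0 - 1)/(ne00 - 1)).  For plain bilinear the source
// coordinate of output pixel i0 is x = (i0 + 0.5)/sf0 - 0.5 (half-pixel centres), clamped
// at the borders; the antialiased variant instead weights source pixels with a triangle
// filter whose support grows when downscaling, following PyTorch's F.interpolate.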
 7471
 7472static void ggml_compute_forward_upscale_f32(
 7473    const ggml_compute_params * params,
 7474    ggml_tensor * dst) {
 7475
 7476    const ggml_tensor * src0 = dst->src[0];
 7477
 7478    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 7479
 7480    const int ith = params->ith;
 7481    const int nth = params->nth;
 7482
 7483    GGML_TENSOR_UNARY_OP_LOCALS
 7484
 7485    float sf0 = (float)ne0/src0->ne[0];
 7486    float sf1 = (float)ne1/src0->ne[1];
 7487    float sf2 = (float)ne2/src0->ne[2];
 7488    float sf3 = (float)ne3/src0->ne[3];
 7489    float pixel_offset = 0.5f;
 7490
 7491    const int32_t mode_flags = ggml_get_op_params_i32(dst, 0);
 7492    const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
 7493
 7494    if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
 7495        pixel_offset = 0.0f;
 7496        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
 7497        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
 7498    }
 7499
 7500    if (mode == GGML_SCALE_MODE_NEAREST) {
 7501        for (int64_t i3 = 0; i3 < ne3; i3++) {
 7502            const int64_t i03 = i3 / sf3;
 7503            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
 7504                const int64_t i02 = i2 / sf2;
 7505                for (int64_t i1 = 0; i1 < ne1; i1++) {
 7506                    const int64_t i01 = i1 / sf1;
 7507                    for (int64_t i0 = 0; i0 < ne0; i0++) {
 7508                        const int64_t i00 = i0 / sf0;
 7509
 7510                        const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 7511                              float * y = (float *)((char *)  dst->data +  i0*nb0  +  i1*nb1  +  i2*nb2  +  i3*nb3);
 7512
 7513                        *y = *x;
 7514                    }
 7515                }
 7516            }
 7517        }
 7518    } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) {
 7519        // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True)
 7520        // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp
 7521        auto triangle_filter = [](float x) -> float {
 7522            return std::max(1.0f - fabsf(x), 0.0f);
 7523        };
 7524
 7525        // support and invscale, minimum 1 pixel for bilinear
 7526        const float support1  = std::max(1.0f, 1.0f / sf1);
 7527        const float invscale1 = 1.0f / support1;
 7528        const float support0  = std::max(1.0f, 1.0f / sf0);
 7529        const float invscale0 = 1.0f / support0;
 7530
 7531        for (int64_t i3 = 0; i3 < ne3; i3++) {
 7532            const int64_t i03 = i3 / sf3;
 7533            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
 7534                const int64_t i02 = i2 / sf2;
 7535                for (int64_t i1 = 0; i1 < ne1; i1++) {
 7536                    const float y = ((float) i1 + pixel_offset) / sf1;
 7537                    for (int64_t i0 = 0; i0 < ne0; i0++) {
 7538                        const float x = ((float) i0 + pixel_offset) / sf0;
 7539
 7540                        // the range of source pixels that contribute
 7541                        const int64_t x_min = std::max<int64_t>(x - support0 + pixel_offset, 0);
 7542                        const int64_t x_max = std::min<int64_t>(x + support0 + pixel_offset, ne00);
 7543                        const int64_t y_min = std::max<int64_t>(y - support1 + pixel_offset, 0);
 7544                        const int64_t y_max = std::min<int64_t>(y + support1 + pixel_offset, ne01);
 7545
 7546                        // bilinear filter with antialiasing
 7547                        float val = 0.0f;
 7548                        float total_weight = 0.0f;
 7549
 7550                        for (int64_t sy = y_min; sy < y_max; sy++) {
 7551                            const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1);
 7552
 7553                            for (int64_t sx = x_min; sx < x_max; sx++) {
 7554                                const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0);
 7555                                const float weight = weight_x * weight_y;
 7556
 7557                                if (weight <= 0.0f) {
 7558                                    continue;
 7559                                }
 7560
 7561                                const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03);
 7562                                val += pixel * weight;
 7563                                total_weight += weight;
 7564                            }
 7565                        }
 7566
 7567                        if (total_weight > 0.0f) {
 7568                            val /= total_weight;
 7569                        }
 7570
 7571                        float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 7572                        *dst_ptr = val;
 7573                    }
 7574                }
 7575            }
 7576        }
 7577    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
 7578        for (int64_t i3 = 0; i3 < ne3; i3++) {
 7579            const int64_t i03 = i3 / sf3;
 7580            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
 7581                const int64_t i02 = i2 / sf2;
 7582                for (int64_t i1 = 0; i1 < ne1; i1++) {
 7583                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
 7584                    int64_t y0 = (int64_t)floorf(y);
 7585                    int64_t y1 = y0 + 1;
 7586
 7587                    y0 = std::max(int64_t(0), std::min(y0, ne01 - 1));
 7588                    y1 = std::max(int64_t(0), std::min(y1, ne01 - 1));
 7589
 7590                    float dy = y - (float)y0;
 7591                    dy = std::max(0.0f, std::min(dy, 1.0f));
 7592
 7593                    for (int64_t i0 = 0; i0 < ne0; i0++) {
 7594                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
 7595                        int64_t x0 = (int64_t)floorf(x);
 7596                        int64_t x1 = x0 + 1;
 7597
 7598                        x0 = std::max(int64_t(0), std::min(x0, ne00 - 1));
 7599                        x1 = std::max(int64_t(0), std::min(x1, ne00 - 1));
 7600
 7601                        float dx = x - (float)x0;
 7602                        dx = std::max(0.0f, std::min(dx, 1.0f));
 7603
 7604                        // fetch the four surrounding pixel values and interpolate
 7605                        const float a = *(const float *)((const char *)src0->data + x0*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
 7606                        const float b = *(const float *)((const char *)src0->data + x1*nb00 + y0*nb01 + i02*nb02 + i03*nb03);
 7607                        const float c = *(const float *)((const char *)src0->data + x0*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
 7608                        const float d = *(const float *)((const char *)src0->data + x1*nb00 + y1*nb01 + i02*nb02 + i03*nb03);
 7609
 7610                        const float val = a*(1 - dx)*(1 - dy) + b*dx*(1 - dy) + c*(1 - dx)*dy + d*dx*dy;
 7611
 7612                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 7613                        *y_dst = val;
 7614                    }
 7615                }
 7616            }
 7617        }
 7618    } else if (mode == GGML_SCALE_MODE_BICUBIC) {
 7619        // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
 7620        const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
 7621        auto weight1 = [a](float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
 7622        auto weight2 = [a](float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
 7623        auto bicubic = [=](float p0, float p1, float p2, float p3, float x) {
 7624            const float w0 = weight2(x + 1);
 7625            const float w1 = weight1(x + 0);
 7626            const float w2 = weight1(1 - x);
 7627            const float w3 = weight2(2 - x);
 7628            return p0*w0 + p1*w1 + p2*w2 + p3*w3;
 7629        };
 7630
 7631        for (int64_t i3 = 0; i3 < ne3; i3++) {
 7632            const int64_t i03 = i3 / sf3;
 7633            for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
 7634                const int64_t i02 = i2 / sf2;
 7635                for (int64_t i1 = 0; i1 < ne1; i1++) {
 7636                    const float y = ((float)i1 + pixel_offset) / sf1 - pixel_offset;
 7637                    const int64_t y0 = (int64_t)floorf(y);
 7638                    const float dy = y - (float)y0;
 7639
 7640                    for (int64_t i0 = 0; i0 < ne0; i0++) {
 7641                        const float x = ((float)i0 + pixel_offset) / sf0 - pixel_offset;
 7642                        const int64_t x0 = (int64_t)floorf(x);
 7643                        const float dx = x - (float)x0;
 7644
 7645                        auto p = [=](int64_t x_off, int64_t y_off) -> float {
 7646                            int64_t i00 = std::max(int64_t(0), std::min(x0 + x_off, ne00 - 1));
 7647                            int64_t i01 = std::max(int64_t(0), std::min(y0 + y_off, ne01 - 1));
 7648                            return *(const float *)((const char *)src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
 7649                        };
 7650
 7651                        const float val = bicubic(
 7652                            bicubic(p(-1,-1), p(0,-1), p(1,-1), p(2,-1), dx),
 7653                            bicubic(p(-1, 0), p(0, 0), p(1, 0), p(2, 0), dx),
 7654                            bicubic(p(-1, 1), p(0, 1), p(1, 1), p(2, 1), dx),
 7655                            bicubic(p(-1, 2), p(0, 2), p(1, 2), p(2, 2), dx), dy);
 7656
 7657                        float * y_dst = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
 7658                        *y_dst = val;
 7659                    }
 7660                }
 7661            }
 7662        }
 7663    } else {
 7664        GGML_ABORT("unsupported upscale mode");
 7665    }
 7666}
 7667
 7668void ggml_compute_forward_upscale(
 7669    const ggml_compute_params * params,
 7670    ggml_tensor * dst) {
 7671
 7672    const ggml_tensor * src0 = dst->src[0];
 7673
 7674    switch (src0->type) {
 7675        case GGML_TYPE_F32:
 7676            {
 7677                ggml_compute_forward_upscale_f32(params, dst);
 7678            } break;
 7679        default:
 7680            {
 7681                GGML_ABORT("fatal error");
 7682            }
 7683    }
 7684}
 7685
 7686
 7687// ggml_compute_forward_pad
 7688
 7689template<bool circular_t>
 7690static void ggml_compute_forward_pad_f32(
 7691    const ggml_compute_params * params,
 7692          ggml_tensor * dst) {
 7693
 7694    const ggml_tensor * src0 = dst->src[0];
 7695
 7696    assert(dst->nb[0] == sizeof(float));
 7697
 7698    const int ith = params->ith;
 7699    const int nth = params->nth;
 7700
 7701    GGML_TENSOR_UNARY_OP_LOCALS
 7702
 7703    float * dst_ptr = (float *) dst->data;
 7704    const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
 7705    const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
 7706    const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
 7707    const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
 7708    const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
 7709    const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
 7710    const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
 7711    const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
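    // op_params[0..7] hold the left/right padding per dimension (lp0, rp0, ..., lp3, rp3);
    // op_params[8] selects circular padding (see ggml_compute_forward_pad below)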
 7712
 7713    // TODO: optimize
 7714
 7715    for (int64_t i2 = 0; i2 < ne2; ++i2) {
 7716        for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
 7717            for (int64_t i0 = 0; i0 < ne0; ++i0) {
 7718                for (int64_t i3 = 0; i3 < ne3; ++i3) {
                    // circular padding wraps each dimension around (like a torus): out-of-range indices map back into the source tensor
 7720                    if constexpr (circular_t) {
 7721                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
 7722                        const int64_t src_i0 = ggml_wrap_around(i0 - lp0, ne00);
 7723                        const int64_t src_i1 = ggml_wrap_around(i1 - lp1, ne01);
 7724                        const int64_t src_i2 = ggml_wrap_around(i2 - lp2, ne02);
 7725                        const int64_t src_i3 = ggml_wrap_around(i3 - lp3, ne03);
 7726
 7727                        const int64_t src_idx =
 7728                            src_i3*nb03 +
 7729                            src_i2*nb02 +
 7730                            src_i1*nb01 +
 7731                            src_i0*nb00;
 7732
 7733                        const float * src_ptr = (const float *)((char *) src0->data + src_idx);
 7734                        dst_ptr[dst_idx] = *src_ptr;
 7735                    } else {
 7736                        const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
                        if ((i0 >= lp0 && i0 < ne0 - rp0)
                            && (i1 >= lp1 && i1 < ne1 - rp1)
                            && (i2 >= lp2 && i2 < ne2 - rp2)
                            && (i3 >= lp3 && i3 < ne3 - rp3)) {
 7741                            const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
 7742                            const float * src_ptr = (const float *)((char *) src0->data + src_idx);
 7743                            dst_ptr[dst_idx] = *src_ptr;
 7744                        } else {
 7745                            dst_ptr[dst_idx] = 0;
 7746                        }
 7747                    }
 7748                }
 7749            }
 7750        }
 7751    }
 7752}
 7753
 7754
 7755void ggml_compute_forward_pad(
 7756    const ggml_compute_params * params,
 7757    ggml_tensor * dst) {
 7758    const ggml_tensor * src0 = dst->src[0];
 7759    const bool circular = (bool) ggml_get_op_params_i32(dst, 8);
 7760    switch (src0->type) {
 7761        case GGML_TYPE_F32:
 7762            {
 7763                if (circular) {
 7764                    ggml_compute_forward_pad_f32<true>(params, dst);
 7765                } else {
 7766                    ggml_compute_forward_pad_f32<false>(params, dst);
 7767                }
 7768            } break;
 7769        default:
 7770            {
 7771                GGML_ABORT("fatal error");
 7772            }
 7773    }
 7774}
 7775
 7776// ggml_compute_forward_pad_reflect_1d
 7777
 7778void ggml_compute_forward_pad_reflect_1d(
 7779        const ggml_compute_params * params,
 7780              ggml_tensor * dst) {
 7781
 7782    const ggml_tensor * src0 = dst->src[0];
 7783
 7784    GGML_ASSERT(src0->type == GGML_TYPE_F32);
 7785    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 7786
 7787    const int ith = params->ith;
 7788    const int nth = params->nth;
 7789
 7790    const int32_t * opts = (const int32_t *) dst->op_params;
 7791    const int p0 = opts[0];
 7792    const int p1 = opts[1];
 7793
 7794    GGML_TENSOR_UNARY_OP_LOCALS
 7795
 7796    for (int64_t i3 = 0; i3 < ne3; i3++) {
 7797        for (int64_t i2 = 0; i2 < ne2; i2++) {
 7798            for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
 7799                float * left  = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 +         p0*nb0);
 7800                float * right = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + (ne0-p1-1)*nb0);
 7801
 7802                ggml_vec_cpy_f32(ne00, left, (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
 7803
 7804                for (int i0 = 1; i0 <= p0; i0++) { left[-i0] = left[i0];   }
 7805                for (int i0 = 1; i0 <= p1; i0++) { right[i0] = right[-i0]; }
 7806            }
 7807        }
 7808    }
 7809}
 7810
 7811// ggml_compute_forward_roll
 7812
 7813static int64_t ggml_wrap_index(int64_t i, int64_t ne) {
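    // corrects by at most one period: callers must keep i within [-ne, 2*ne)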
 7814    if (i < 0) {
 7815        return i + ne;
 7816    } else if (i >= ne) {
 7817        return i - ne;
 7818    }
 7819    return i;
 7820}
 7821
 7822static void ggml_compute_forward_roll_f32(
 7823        const ggml_compute_params * params,
 7824        ggml_tensor * dst) {
 7825
 7826    const ggml_tensor * src0 = dst->src[0];
 7827    const float * src_data = (const float *) src0->data;
 7828    float * dst_data = (float *) dst->data;
 7829
 7830    GGML_TENSOR_UNARY_OP_LOCALS
 7831
 7832    const int s0 = ggml_get_op_params_i32(dst, 0);
 7833    const int s1 = ggml_get_op_params_i32(dst, 1);
 7834    const int s2 = ggml_get_op_params_i32(dst, 2);
 7835    const int s3 = ggml_get_op_params_i32(dst, 3);
 7836
 7837    const int64_t total = ne1 * ne2 * ne3;
 7838    const int64_t per_thread = (total + params->nth) / params->nth;
 7839    const int64_t start = params->ith * per_thread;
 7840    const int64_t end   = std::min(start + per_thread, total);
 7841
 7842    for (int64_t i = start; i < end; ++i) {
 7843        const int64_t i1 = i % ne1;
 7844        const int64_t i2 = (i / ne1) % ne2;
 7845        const int64_t i3 = i / (ne2 * ne1);
 7846        float * dst_row = dst_data + (i3*nb3 + i2*nb2 + i1*nb1) / sizeof(float);
 7847
 7848        const int64_t i01 = ggml_wrap_index(i1 - s1, ne01);
 7849        const int64_t i02 = ggml_wrap_index(i2 - s2, ne02);
 7850        const int64_t i03 = ggml_wrap_index(i3 - s3, ne03);
 7851        const float * src_row = src_data + (i03*nb03 + i02*nb02 + i01*nb01) / sizeof(float);
 7852
 7853        const int64_t s = ggml_wrap_index(-s0, ne00);
 7854        const int64_t n = ne00 - s;
 7855        ggml_vec_cpy_f32(n, dst_row,     src_row + s);
 7856        ggml_vec_cpy_f32(s, dst_row + n, src_row);
 7857    }
 7858}
 7859
 7860void ggml_compute_forward_roll(
 7861        const ggml_compute_params * params,
 7862        ggml_tensor * dst) {
 7863
 7864    const ggml_tensor * src0 = dst->src[0];
 7865
 7866    switch (src0->type) {
 7867        case GGML_TYPE_F32:
 7868            {
 7869                ggml_compute_forward_roll_f32(params, dst);
 7870            } break;
 7871        default:
 7872            {
 7873                GGML_ABORT("fatal error");
 7874            }
 7875    }
 7876}
 7877
 7878// ggml_compute_forward_arange
 7879
 7880static void ggml_compute_forward_arange_f32(
 7881    const ggml_compute_params * params,
 7882    ggml_tensor * dst) {
 7883
 7884    GGML_ASSERT(dst->nb[0] == sizeof(float));
 7885
 7886    const int ith = params->ith;
 7887    const int nth = params->nth;
 7888
 7889    const float start = ggml_get_op_params_f32(dst, 0);
 7890    const float stop  = ggml_get_op_params_f32(dst, 1);
 7891    const float step  = ggml_get_op_params_f32(dst, 2);
 7892
 7893    const int64_t steps = (int64_t) ceilf((stop - start) / step);
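    // e.g. start=0, stop=5, step=1.5 -> steps = ceil(5/1.5) = 4 -> values 0, 1.5, 3.0, 4.5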
 7894
 7895    GGML_ASSERT(ggml_nelements(dst) == steps);
 7896
 7897    for (int64_t i = ith; i < steps; i+= nth) {
 7898        float value = start + step * i;
 7899        ((float *)dst->data)[i] = value;
 7900    }
 7901}
 7902
 7903void ggml_compute_forward_arange(
 7904    const ggml_compute_params * params,
 7905    ggml_tensor * dst) {
 7906    switch (dst->type) {
 7907        case GGML_TYPE_F32:
 7908            {
 7909                ggml_compute_forward_arange_f32(params, dst);
 7910            } break;
 7911        default:
 7912            {
 7913                GGML_ABORT("fatal error");
 7914            }
 7915    }
 7916}
 7917
 7918static void ggml_compute_forward_timestep_embedding_f32(
 7919    const ggml_compute_params * params,
 7920    ggml_tensor * dst) {
 7921
 7922    const ggml_tensor * src0 = dst->src[0];
 7923
 7924    GGML_ASSERT(src0->nb[0] == sizeof(float));
 7925
 7926    const int ith = params->ith;
 7927    const int nth = params->nth;
 7928
 7929    GGML_TENSOR_UNARY_OP_LOCALS
 7930
 7931    const int dim = ggml_get_op_params_i32(dst, 0);
 7932    const int max_period = ggml_get_op_params_i32(dst, 1);
 7933
 7934    int half = dim / 2;
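    // each output row holds a sinusoidal embedding of the corresponding timestep:
    // cos(arg) in columns [0, half) and sin(arg) in [half, 2*half); an odd dim is padded with a trailing 0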
 7935
 7936    for (int64_t i = 0; i < ne00; i++) {
 7937        float * embed_data = (float *)((char *)  dst->data +  i*nb1);
 7938        for (int64_t j = ith; j < half; j += nth) {
 7939            float timestep = ((float *)src0->data)[i];
 7940            float freq = (float)expf(-logf(max_period) * j / half);
 7941            float arg = timestep * freq;
 7942            embed_data[j] = cosf(arg);
 7943            embed_data[j + half] = sinf(arg);
 7944        }
 7945        if (dim % 2 != 0 && ith == 0) {
 7946            embed_data[2 * half] = 0.f;
 7947        }
 7948    }
 7949}
 7950
 7951void ggml_compute_forward_timestep_embedding(
 7952    const ggml_compute_params * params,
 7953    ggml_tensor * dst) {
 7954
 7955    const ggml_tensor * src0 = dst->src[0];
 7956
 7957    switch (src0->type) {
 7958        case GGML_TYPE_F32:
 7959            {
 7960                ggml_compute_forward_timestep_embedding_f32(params, dst);
 7961            } break;
 7962        default:
 7963            {
 7964                GGML_ABORT("fatal error");
 7965            }
 7966    }
 7967}
 7968
 7969// ggml_compute_forward_argsort
 7970
 7971template<enum ggml_sort_order order>
 7972struct cmp_argsort {
 7973    const float * data;
 7974    bool operator()(int32_t a, int32_t b) const {
 7975        if constexpr (order == GGML_SORT_ORDER_ASC) {
 7976            return data[a] < data[b];
 7977        } else {
 7978            return data[a] > data[b];
 7979        }
 7980    }
 7981};
 7982
 7983static void ggml_compute_forward_argsort_f32(
 7984    const ggml_compute_params * params,
 7985    ggml_tensor * dst) {
 7986
 7987    const ggml_tensor * src0 = dst->src[0];
 7988
 7989    GGML_TENSOR_UNARY_OP_LOCALS
 7990
 7991    GGML_ASSERT(nb0 == sizeof(float));
 7992
 7993    const int ith = params->ith;
 7994    const int nth = params->nth;
 7995
 7996    const int64_t nr = ggml_nrows(src0);
 7997
 7998    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
 7999
 8000    for (int64_t i = ith; i < nr; i += nth) {
 8001        const float * src_data = (float *)((char *) src0->data + i*nb01);
 8002
 8003        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
 8004
 8005        for (int64_t j = 0; j < ne0; j++) {
 8006            dst_data[j] = j;
 8007        }
 8008
 8009        switch (order) {
 8010            case GGML_SORT_ORDER_ASC:
 8011                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_ASC>{src_data});
 8012                break;
 8013
 8014            case GGML_SORT_ORDER_DESC:
 8015                std::sort(dst_data, dst_data + ne0, cmp_argsort<GGML_SORT_ORDER_DESC>{src_data});
 8016                break;
 8017
 8018            default:
 8019                GGML_ABORT("invalid sort order");
 8020        }
 8021    }
 8022}
 8023
 8024void ggml_compute_forward_argsort(
 8025    const ggml_compute_params * params,
 8026    ggml_tensor * dst) {
 8027
 8028    const ggml_tensor * src0 = dst->src[0];
 8029
 8030    switch (src0->type) {
 8031        case GGML_TYPE_F32:
 8032            {
 8033                ggml_compute_forward_argsort_f32(params, dst);
 8034            } break;
 8035        default:
 8036            {
 8037                GGML_ABORT("fatal error");
 8038            }
 8039    }
 8040}
 8041
 8042// ggml_compute_forward_top_k
 8043
 8044struct cmp_top_k {
 8045    const float * data;
 8046    bool operator()(int32_t a, int32_t b) const {
 8047        return data[a] > data[b];
 8048    }
 8049};
 8050
 8051static void ggml_compute_forward_top_k_f32(
 8052    const ggml_compute_params * params,
 8053    ggml_tensor * dst) {
 8054
 8055    const ggml_tensor * src0 = dst->src[0];
 8056
 8057    GGML_TENSOR_UNARY_OP_LOCALS
 8058
 8059    GGML_ASSERT(nb0 == sizeof(float));
 8060
 8061    const int ith = params->ith;
 8062    const int nth = params->nth;
 8063
 8064    const int64_t nr = ggml_nrows(src0);
 8065
 8066    const int top_k = ne0;
 8067
 8068    int32_t * tmp = (int32_t *) params->wdata + (ne00 + CACHE_LINE_SIZE_F32) * ith;
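    // wdata holds ne00 int32 indices per thread (plus cache-line padding between threads);
    // std::partial_sort below places the indices of the top_k largest values in the first top_k slots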
 8069
 8070    for (int64_t i = ith; i < nr; i += nth) {
 8071        const float * src_data = (float *)((char *) src0->data + i*nb01);
 8072
 8073        for (int64_t j = 0; j < ne00; j++) {
 8074            tmp[j] = j;
 8075        }
 8076
 8077        std::partial_sort(tmp, tmp + top_k, tmp + ne00, cmp_top_k{src_data});
 8078
 8079        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
 8080
 8081        std::copy(tmp, tmp + top_k, dst_data);
 8082
        // deliberately swap the first two results to emphasize that the order of the returned top-k indices is not guaranteed
 8084        if (top_k > 1) {
 8085            std::swap(dst_data[0], dst_data[1]);
 8086        }
 8087    }
 8088}
 8089
 8090void ggml_compute_forward_top_k(
 8091    const ggml_compute_params * params,
 8092    ggml_tensor * dst) {
 8093
 8094    const ggml_tensor * src0 = dst->src[0];
 8095
 8096    switch (src0->type) {
 8097        case GGML_TYPE_F32:
 8098            {
 8099                ggml_compute_forward_top_k_f32(params, dst);
 8100            } break;
 8101        default:
 8102            {
 8103                GGML_ABORT("fatal error");
 8104            }
 8105    }
 8106}
 8107
 8108static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 8109        const ggml_compute_params * params,
 8110        ggml_tensor * dst,
 8111        int ir0, int ir1,
 8112        int64_t ic_start, int64_t ic_end,
 8113        float * partials, int64_t partial_stride) {
 8114
 8115    const bool write_partials = (partials != nullptr);
 8116    const ggml_tensor * q     = dst->src[0];
 8117    const ggml_tensor * k     = dst->src[1];
 8118    const ggml_tensor * v     = dst->src[2];
 8119    const ggml_tensor * mask  = dst->src[3];
 8120    const ggml_tensor * sinks = dst->src[4];
 8121
 8122    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
 8123    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
 8124    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
 8125    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
 8126    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
 8127    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
 8128    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
 8129    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 8130
 8131    const int64_t DK = nek0;
 8132    const int64_t DV = nev0;
 8133    const int64_t N  = neq1;
 8134
 8135    GGML_ASSERT(ne0 == DV);
 8136    GGML_ASSERT(ne2 == N);
 8137
 8138    // input tensor rows must be contiguous
 8139    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
 8140    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
 8141    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 8142
 8143    GGML_ASSERT(neq0 == DK);
 8144    GGML_ASSERT(nek0 == DK);
 8145    GGML_ASSERT(nev0 == DV);
 8146
 8147    GGML_ASSERT(neq1 == N);
 8148
 8149    // dst cannot be transposed or permuted
 8150    GGML_ASSERT(nb0 == sizeof(float));
 8151    GGML_ASSERT(nb0 <= nb1);
 8152    GGML_ASSERT(nb1 <= nb2);
 8153    GGML_ASSERT(nb2 <= nb3);
 8154
 8155    // broadcast factors
 8156    const int64_t rk2 = neq2/nek2;
 8157    const int64_t rk3 = neq3/nek3;
 8158
 8159    const int64_t rv2 = neq2/nev2;
 8160    const int64_t rv3 = neq3/nev3;
 8161
 8162    // parallelize by q rows using ggml_vec_dot_f32
 8163
 8164    float scale         = 1.0f;
 8165    float max_bias      = 0.0f;
 8166    float logit_softcap = 0.0f;
 8167
 8168    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
 8169    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
 8170    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
 8171
 8172    if (logit_softcap != 0) {
 8173        scale /= logit_softcap;
 8174    }
 8175
 8176    const uint32_t n_head      = neq2;
 8177    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 8178
 8179    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 8180    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
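
    // ALiBi-style slopes: head h gets slope m0^(h+1) for h < n_head_log2 and
    // m1^(2*(h - n_head_log2) + 1) otherwise (see the per-row slope computation below)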
 8181
 8182    ggml_type         const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
 8183    ggml_from_float_t const q_to_vec_dot   = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
 8184    ggml_vec_dot_t    const kq_vec_dot     = ggml_get_type_traits_cpu(k->type)->vec_dot;
 8185    ggml_to_float_t   const v_to_float     = ggml_get_type_traits(v->type)->to_float;
 8186
 8187    GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
 8188    GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 8189
 8190    int ith = params->ith;
 8191
 8192    for (int ir = ir0; ir < ir1; ++ir) {
 8193        // q indices
 8194        const int iq3 = ir/(neq2*neq1);
 8195        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
 8196        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 8197
 8198        const uint32_t h = iq2; // head index
 8199        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 8200
 8201        float S = 0.0f;      // sum
 8202        float M = -INFINITY; // maximum KQ value
 8203
 8204        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
 8205        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
 8206        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
 8207        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
 8208
 8209        if (v->type == GGML_TYPE_F16) {
 8210            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
 8211        } else {
 8212            memset(VKQ32, 0, DV*sizeof(float));
 8213        }
 8214
 8215        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]) : NULL;
 8216
 8217        // k indices
 8218        const int ik3 = iq3 / rk3;
 8219        const int ik2 = iq2 / rk2;
 8220
 8221        // v indices
 8222        const int iv3 = iq3 / rv3;
 8223        const int iv2 = iq2 / rv2;
 8224
 8225        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
 8226        q_to_vec_dot(pq, Q_q, DK);
 8227
 8228        // online softmax / attention
 8229        // loop over n_kv and n_head_kv
 8230        // ref: https://arxiv.org/pdf/2112.05682.pdf
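        //
        // streaming update: with Mnew = max(M, s),
        //   S   <- S*expf(M - Mnew) + expf(s - Mnew)
        //   VKQ <- VKQ*expf(M - Mnew) + v*expf(s - Mnew)
        // and the final result is VKQ/S, i.e. the softmax-weighted sum of the V rows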
 8231
 8232        for (int64_t ic = ic_start; ic < ic_end; ++ic) {
 8233            const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
 8234            if (mv == -INFINITY) {
 8235                continue;
 8236            }
 8237
 8238            float s; // KQ value
 8239
 8240            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
 8241            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
 8242
 8243            s = s*scale; // scale KQ value
 8244
 8245            if (logit_softcap != 0.0f) {
 8246                s = logit_softcap*tanhf(s);
 8247            }
 8248
 8249            s += mv; // apply mask
 8250
 8251            const float Mold = M;
 8252
 8253            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
 8254            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 8255
 8256            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 8257
 8258            if (v->type == GGML_TYPE_F16) {
 8259                if (s > M) {
 8260                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
 8261                    M = s;
 8262                    ms = expf(Mold - M);
 8263
 8264                    // V = V*expf(Mold - M)
 8265                    ggml_vec_scale_f16(DV, VKQ16, ms);
 8266                } else {
 8267                    // no new maximum, ms == 1.0f, vs != 1.0f
 8268                    vs = expf(s - M);
 8269                }
 8270
 8271                // V += v*expf(s - M)
 8272                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
 8273            } else {
 8274                if (s > M) {
 8275                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
 8276                    M = s;
 8277                    ms = expf(Mold - M);
 8278
 8279                    // V = V*expf(Mold - M)
 8280                    ggml_vec_scale_f32(DV, VKQ32, ms);
 8281                } else {
 8282                    // no new maximum, ms == 1.0f, vs != 1.0f
 8283                    vs = expf(s - M);
 8284                }
 8285
 8286                // V += v*expf(s - M)
 8287                if (v_to_float) {
 8288                    v_to_float(v_data, V32, DV);
 8289                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
 8290                } else {
 8291                    // V is F32
 8292                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
 8293                }
 8294            }
 8295
 8296            S = S*ms + vs; // scale and increment sum with partial sum
 8297        }
 8298
 8299        if (v->type == GGML_TYPE_F16) {
 8300            for (int64_t d = 0; d < DV; ++d) {
 8301                VKQ32[d] = GGML_CPU_FP16_TO_FP32(VKQ16[d]);
 8302            }
 8303        }
 8304
 8305        // sinks - apply only on the first kv-chunk
 8306        if (sinks && ic_start == 0) {
 8307            const float s = ((float *)((char *) sinks->data))[h];
 8308
 8309            float ms = 1.0f;
 8310            float vs = 1.0f;
 8311
 8312            if (s > M) {
 8313                ms = expf(M - s);
 8314                M = s;
 8315                ggml_vec_scale_f32(DV, VKQ32, ms);
 8316            } else {
 8317                vs = expf(s - M);
 8318            }
 8319
 8320            S = S*ms + vs;
 8321        }
 8322
 8323        if (write_partials) {
 8324            // Write M, S, VKQ to partials for later reduction
 8325            // partials layout: [M, S, VKQ[DV]] per query head
 8326            float * partial = partials + ir * partial_stride;
 8327            partial[0] = M;
 8328            partial[1] = S;
 8329            memcpy(partial + 2, VKQ32, DV * sizeof(float));
 8330        } else {
 8331            // V /= S
 8332            const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
 8333            ggml_vec_scale_f32(DV, VKQ32, S_inv);
 8334
 8335            // dst indices
 8336            const int i1 = iq1;
 8337            const int i2 = iq2;
 8338            const int i3 = iq3;
 8339
 8340            // permute(0, 2, 1, 3)
 8341            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
 8342        }
 8343    }
 8344}
 8345
 8346static void ggml_compute_forward_flash_attn_ext_tiled(
 8347        const ggml_compute_params * params,
 8348        ggml_tensor * dst,
 8349        int ir0, int ir1) {
 8350    const ggml_tensor * q     = dst->src[0];
 8351    const ggml_tensor * k     = dst->src[1];
 8352    const ggml_tensor * v     = dst->src[2];
 8353    const ggml_tensor * mask  = dst->src[3];
 8354    const ggml_tensor * sinks = dst->src[4];
 8355
 8356    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
 8357    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
 8358    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
 8359    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
 8360    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
 8361    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
 8362    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
 8363    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 8364
 8365    const int64_t DK = nek0;
 8366    const int64_t DV = nev0;
 8367    const int64_t N  = neq1;
 8368
 8369    GGML_ASSERT(ne0 == DV);
 8370    GGML_ASSERT(ne2 == N);
 8371
 8372    // input tensor rows must be contiguous
 8373    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
 8374    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
 8375    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 8376
 8377    GGML_ASSERT(neq0 == DK);
 8378    GGML_ASSERT(nek0 == DK);
 8379    GGML_ASSERT(nev0 == DV);
 8380
 8381    GGML_ASSERT(neq1 == N);
 8382
 8383    // dst cannot be transposed or permuted
 8384    GGML_ASSERT(nb0 == sizeof(float));
 8385    GGML_ASSERT(nb0 <= nb1);
 8386    GGML_ASSERT(nb1 <= nb2);
 8387    GGML_ASSERT(nb2 <= nb3);
 8388
 8389    GGML_ASSERT(k->type == v->type);
 8390    const ggml_type kv_type = k->type;
 8391
 8392    const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type);
 8393    const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float;
 8394    const ggml_vec_dot_t    kv_vec_dot    = kv_type_traits_cpu->vec_dot;
 8395    const size_t kv_type_size = ggml_type_size(kv_type);
 8396
 8397    // broadcast factors
 8398    const int64_t rk2 = neq2/nek2;
 8399    const int64_t rk3 = neq3/nek3;
 8400
 8401    const int64_t rv2 = neq2/nev2;
 8402    const int64_t rv3 = neq3/nev3;
 8403
 8404    float scale         = 1.0f;
 8405    float max_bias      = 0.0f;
 8406    float logit_softcap = 0.0f;
 8407
 8408    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
 8409    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
 8410    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
 8411
 8412    if (logit_softcap != 0) {
 8413        scale /= logit_softcap;
 8414    }
 8415
 8416    const uint32_t n_head      = neq2;
 8417    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 8418
 8419    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 8420    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 8421
 8422    int ith = params->ith;
 8423
 8424    static constexpr int Q_TILE_SZ  = ggml_fa_tile_config::Q;
 8425    static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV;
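
    // process Q_TILE_SZ query rows against KV_TILE_SZ key/value rows at a time, keeping a
    // running max M[] and sum S[] per query row (same online softmax as the one-chunk path)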
 8426
 8427    GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ");
 8428
 8429    int ir = ir0;
 8430    while (ir < ir1) {
 8431        // q indices for the start of this tile
 8432        const int iq3 = ir/(neq2*neq1);
 8433        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
 8434        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 8435
 8436        // Number of valid rows in this tile:
 8437        // - limited by tile size (Q_TILE_SZ)
 8438        // - limited by chunk boundary (ir1 - ir)
 8439        // - limited by head boundary (neq1 - iq1) to avoid crossing into next head
 8440        const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1)));
 8441        GGML_ASSERT(tile_rows > 0);
 8442
 8443        const uint32_t h = iq2; // head index
 8444        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 8445
 8446        float S[Q_TILE_SZ];
 8447        float M[Q_TILE_SZ];
 8448
 8449        for (int i = 0 ; i < Q_TILE_SZ; ++i) {
 8450            S[i] = 0.;
 8451            M[i] = -INFINITY;
 8452        }
 8453
 8454        // Per-thread scratch layout:
 8455        // Q_q:    Q_TILE_SZ * DK (converted Q tile in KV type)
 8456        // KQ:     Q_TILE_SZ * KV_TILE_SZ (attention scores in float)
 8457        // mask:   Q_TILE_SZ * KV_TILE_SZ (mask in float)
 8458        // VKQ32:  Q_TILE_SZ * DV (FP32 output accumulator)
        // V32:    KV_TILE_SZ * DV (F32 buffer for V tile - used for F16 conversion)
 8460        float * base  = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32);
 8461
 8462        void  * Q_q    = base;
 8463        float * KQ     = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float));
 8464        float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ;
 8465        float * VKQ32  = mask32 + Q_TILE_SZ * KV_TILE_SZ;
 8466        float * V32    = VKQ32 + Q_TILE_SZ * DV;  // F32 buffer for V tile
 8467
 8468        memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float));
 8469        memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float));
 8470
 8471        // k indices
 8472        const int ik3 = iq3 / rk3;
 8473        const int ik2 = iq2 / rk2;
 8474
 8475        // v indices
 8476        const int iv3 = iq3 / rv3;
 8477        const int iv2 = iq2 / rv2;
 8478
 8479        for (int tq = 0; tq < tile_rows; tq++) {
 8480            const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3));
 8481            kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK);
 8482        }
 8483        // Zero-pad remaining rows
 8484        for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) {
 8485            memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size);
 8486        }
 8487
 8488        for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) {
 8489
 8490            // skip the tile entirely if all the masks are -inf
 8491            if (mask) {
 8492                bool can_skip = true;
 8493                for (int tq = 0; tq < tile_rows; tq++) {
 8494                    const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]);
 8495                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
 8496                        mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]);
 8497                        if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) {
 8498                            can_skip = false;
 8499                        }
 8500                    }
 8501                }
 8502
 8503                if (can_skip) {
 8504                    continue;
 8505                }
 8506            }
 8507
 8508            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
 8509                const void * q_row = (const char *)Q_q + tq * DK * kv_type_size;
 8510                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
 8511                    const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3);
 8512                    float s;
 8513                    kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1);
 8514                    KQ[tq * KV_TILE_SZ + tk] = s * scale;
 8515                }
 8516            }
 8517
 8518            if (logit_softcap != 0.0f) {
 8519                ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ);
 8520                ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap);
 8521            }
 8522
 8523            if (mask) {
 8524                ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32);
 8525            }
 8526
 8527            bool skip[Q_TILE_SZ] = {};
 8528
 8529            for (int tq = 0; tq < Q_TILE_SZ; tq++) {
 8530                float * kq_row = KQ + tq * KV_TILE_SZ;
 8531
 8532                float tile_max;
 8533                ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row);
 8534
 8535                if (tile_max == -INFINITY) {
 8536                    skip[tq] = true;
 8537                    continue;
 8538                }
 8539
 8540                const float Mold = M[tq];
 8541                const float Mnew = fmaxf(Mold, tile_max);
 8542
 8543                if (Mnew > Mold) {
 8544                    const float ms = expf(Mold - Mnew);
 8545                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
 8546                    S[tq] *= ms;
 8547                }
                M[tq] = Mnew;

                // ggml_vec_soft_max_f32 overwrites kq_row with expf(kq_row - Mnew) and returns the sum of those values
                S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew);
 8552            }
 8553
 8554            // Convert V tile to F32 first (if F16), then do MAD
            // On x86, ggml_vec_mad_f16 internally converts F16<->F32 on every load/store, so pre-converting is faster.
 8556            // TODO: on ARM, native f16 should be faster
 8557            if (kv_type == GGML_TYPE_F16) {
 8558                for (int tk = 0; tk < KV_TILE_SZ; tk++) {
 8559                    const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
 8560                    ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV);
 8561                }
 8562                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
 8563                    if (skip[tq]) continue;
 8564                    float * vkq_row = VKQ32 + tq * DV;
 8565                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
 8566                        const float p = KQ[tq * KV_TILE_SZ + tk];
 8567                        ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p);
 8568                    }
 8569                }
 8570            } else {
 8571                for (int tq = 0; tq < Q_TILE_SZ; tq++) {
 8572                    if (skip[tq]) continue;
 8573                    float * vkq_row = VKQ32 + tq * DV;
 8574                    for (int tk = 0; tk < KV_TILE_SZ; tk++) {
 8575                        const float p = KQ[tq * KV_TILE_SZ + tk];
 8576                        const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3));
 8577                        ggml_vec_mad_f32(DV, vkq_row, v_row, p);
 8578                    }
 8579                }
 8580            }
 8581        }
 8582
 8583        // sinks (apply only to valid rows in the tile)
 8584        if (sinks) {
 8585            const float s = ((float *)((char *) sinks->data))[h];
 8586
 8587            for (int tq = 0; tq < tile_rows; tq++) {
 8588                float ms = 1.0f;
 8589                float vs = 1.0f;
 8590
 8591                if (s > M[tq]) {
 8592                    ms = expf(M[tq] - s);
 8593                    ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms);
 8594                } else {
 8595                    vs = expf(s - M[tq]);
 8596                }
 8597
 8598                S[tq] = S[tq] * ms + vs;
 8599            }
 8600        }
 8601
 8602        for (int tq = 0; tq < tile_rows; tq++) {
 8603            // V /= S
 8604            const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq];
 8605            ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv);
 8606
 8607            // dst indices
 8608            const int i1 = iq1 + tq;
 8609            const int i2 = iq2;
 8610            const int i3 = iq3;
 8611
 8612            // permute(0, 2, 1, 3)
 8613            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1);
 8614        }
 8615
 8616        ir += tile_rows;
 8617    }
 8618}
 8619
 8620// Reduction function: combines partial results across KV chunks
 8621// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
 8622static void ggml_flash_attn_ext_reduce_partials(
 8623        const ggml_compute_params * params,
 8624        ggml_tensor * dst,
 8625        const int64_t n_chunks,
 8626        const int64_t chunk_size) {
 8627
 8628    const ggml_tensor * q = dst->src[0];
 8629    const ggml_tensor * k = dst->src[1];
 8630    const ggml_tensor * v = dst->src[2];
 8631
 8632    const int64_t DK        = k->ne[0];
 8633    const int64_t DV        = v->ne[0];
 8634    const int64_t nek1      = k->ne[1];
 8635    const int64_t n_q_heads = q->ne[2];
 8636
 8637    const int ith = params->ith;
 8638    const int nth = params->nth;
 8639
 8640    const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
 8641    float *       thread_wdata     = (float *) params->wdata + ith * wdata_per_thread;
 8642
 8643    const int64_t partials_offset  = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
 8644    const int64_t partial_size     = 2 + DV;
 8645    const float * partials_base    = (const float *) params->wdata + partials_offset;
 8646
 8647    // Output layout
 8648    const int64_t ne1 = dst->ne[1];
 8649    const int64_t ne2 = dst->ne[2];
 8650    const size_t  nb1 = dst->nb[1];
 8651
 8652    // Each thread reduces a subset of query heads
 8653    for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
 8654        float   M_final   = -INFINITY;
 8655        float   S_final   = 0.0f;
 8656        float * VKQ_final = thread_wdata;
 8657        memset(VKQ_final, 0, DV * sizeof(float));
 8658
 8659        // Combine partials from all chunks
 8660        for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
 8661            const int64_t ic_start = chunk_idx * chunk_size;
 8662            if (ic_start >= nek1) continue;
 8663
 8664            const float * partial   = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
 8665            const float   M_chunk   = partial[0];
 8666            const float   S_chunk   = partial[1];
 8667            const float * VKQ_chunk = partial + 2;
 8668
 8669            if (S_chunk == 0.0f) continue;
 8670
 8671            const float M_new     = fmaxf(M_final, M_chunk);
 8672            const float scale_old = expf(M_final - M_new);
 8673            const float scale_new = expf(M_chunk - M_new);
 8674
 8675            for (int64_t d = 0; d < DV; ++d) {
 8676                VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
 8677            }
 8678            S_final = S_final * scale_old + S_chunk * scale_new;
 8679            M_final = M_new;
 8680        }
 8681
 8682        // Normalize and write to output
 8683        if (S_final != 0.0f) {
 8684            const float S_inv = 1.0f / S_final;
 8685            ggml_vec_scale_f32(DV, VKQ_final, S_inv);
 8686        }
 8687        // iq1=0, iq3=0 for decode
 8688        memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
 8689    }
 8690}
 8691
 8692static void ggml_compute_forward_flash_attn_ext_f16(
 8693        const ggml_compute_params * params,
 8694        ggml_tensor * dst) {
 8695
 8696    const ggml_tensor * q     = dst->src[0];
 8697    const ggml_tensor * k     = dst->src[1];
 8698    const ggml_tensor * v     = dst->src[2];
 8699
 8700    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
 8701    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
 8702    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
 8703    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
 8704    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
 8705    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
 8706    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
 8707    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 8708
 8709    const int64_t DK = nek0;
 8710    const int64_t DV = nev0;
 8711    const int64_t N  = neq1;
 8712
 8713
 8714    GGML_ASSERT(ne0 == DV);
 8715    GGML_ASSERT(ne2 == N);
 8716
 8717    // input tensor rows must be contiguous
 8718    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
 8719    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
 8720    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 8721
 8722    GGML_ASSERT(neq0 == DK);
 8723    GGML_ASSERT(nek0 == DK);
 8724    GGML_ASSERT(nev0 == DV);
 8725
 8726    GGML_ASSERT(neq1 == N);
 8727
 8728    // dst cannot be transposed or permuted
 8729    GGML_ASSERT(nb0 == sizeof(float));
 8730    GGML_ASSERT(nb0 <= nb1);
 8731    GGML_ASSERT(nb1 <= nb2);
 8732    GGML_ASSERT(nb2 <= nb3);
 8733
 8734    const int ith = params->ith;
 8735    const int nth = params->nth;
 8736
 8737    // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
 8738    const bool use_ref = params->use_ref;
 8739
 8740    const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
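    // single-token (decode) queries with a long F32/F16 KV cache are parallelized by splitting the
    // KV sequence across threads; each thread produces partial (M, S, VKQ) state that is reduced
    // after the barrier below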
 8741    const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
 8742
 8743    if (use_split_kv_path) {
 8744        const int64_t chunk_size = (nek1 + nth - 1) / nth;
 8745
 8746        // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
 8747        const int64_t partial_size  = 2 + DV;
 8748        float *       partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
 8749
 8750        const int64_t ic_start = ith * chunk_size;
 8751        const int64_t ic_end   = std::min(ic_start + chunk_size, nek1);
 8752
 8753        const int64_t partial_stride = nth * partial_size;
 8754        float *       chunk_partials = partials_base + ith * partial_size;
 8755
 8756        if (ic_start < nek1) {
 8757            for (int64_t q_head = 0; q_head < neq2; q_head++) {
 8758                ggml_compute_forward_flash_attn_ext_f16_one_chunk(
 8759                    params, dst, q_head, q_head + 1, ic_start, ic_end,
 8760                    chunk_partials, partial_stride);
 8761            }
 8762        } else {
 8763            for (int64_t q_head = 0; q_head < neq2; q_head++) {
 8764                float * q_partials = chunk_partials + q_head * partial_stride;
 8765                q_partials[0] = -INFINITY;  // M
 8766                q_partials[1] = 0.0f;       // S
 8767            }
 8768        }
 8769
 8770        ggml_barrier(params->threadpool);
 8771        ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
 8772    } else {
 8773
 8774        // total rows in q
 8775        const int64_t nr = neq1*neq2*neq3;
 8776
 8777        // disable for NUMA
 8778        const bool disable_chunking = ggml_is_numa();
 8779
 8780        // 4x chunks per thread
 8781        int nth_scaled = nth * 4;
 8782        int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
 8783        int64_t nchunk     = (nr + chunk_size - 1) / chunk_size;
 8784
 8785        if (nth == 1 || nchunk < nth || disable_chunking) {
 8786            nchunk = nth;
 8787        }
 8788
 8789        if (ith == 0) {
 8790            ggml_threadpool_chunk_set(params->threadpool, nth);
 8791        }
 8792
 8793        ggml_barrier(params->threadpool);
 8794
 8795        const int64_t dr = (nr + nchunk - 1) / nchunk;
 8796
 8797        static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
 8798        static constexpr int64_t Q_TILE_SZ  = ggml_fa_tile_config::Q;
 8799        const bool use_tiled = !use_ref &&
 8800                               (q->type == GGML_TYPE_F32 &&
 8801                                kv_is_f32_or_f16 &&
 8802                                k->type == v->type &&
 8803                                nek1 % KV_TILE_SZ == 0 &&
 8804                                neq1 >= Q_TILE_SZ);
 8805
 8806        int current_chunk = ith;
 8807
 8808        while (current_chunk < nchunk) {
 8809            const int64_t ir0 = dr * current_chunk;
 8810            const int64_t ir1 = MIN(ir0 + dr, nr);
 8811
 8812            if (use_tiled) {
 8813                ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
 8814            } else {
 8815                ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
 8816            }
 8817
 8818            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
 8819        }
 8820    }
 8821}
 8822
 8823void ggml_compute_forward_flash_attn_ext(
 8824        const ggml_compute_params * params,
 8825        ggml_tensor * dst) {
 8826    switch (dst->op_params[3]) {
 8827        case GGML_PREC_DEFAULT:
 8828        case GGML_PREC_F32:
 8829            {
 8830                // uses F32 accumulators
 8831                ggml_compute_forward_flash_attn_ext_f16(params, dst);
 8832            } break;
 8833        default:
 8834            {
 8835                GGML_ABORT("fatal error");
 8836            }
 8837    }
 8838}
 8839
 8840// ggml_compute_forward_flash_attn_back
 8841
 8842static void ggml_compute_forward_flash_attn_back_f32(
 8843        const ggml_compute_params * params,
 8844        const bool masked,
 8845              ggml_tensor * dst) {
 8846
 8847    const ggml_tensor * q = dst->src[0];
 8848    const ggml_tensor * k = dst->src[1];
 8849    const ggml_tensor * v = dst->src[2];
 8850    const ggml_tensor * d = dst->src[3];
 8851
 8852    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
 8853    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
 8854    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
 8855    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
 8856    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
 8857    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
 8858    GGML_TENSOR_LOCALS(int64_t, ned, d,   ne)
 8859    GGML_TENSOR_LOCALS(size_t,  nbd, d,   nb)
 8860    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
 8861    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 8862
 8863    const int ith = params->ith;
 8864    const int nth = params->nth;
 8865
 8866    const int64_t D = neq0;
 8867    const int64_t N = neq1;
 8868    const int64_t P = nek1 - N;
 8869    const int64_t M = P + N;
 8870
 8871    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
 8872    const int mxDM = MAX(D, Mup);
 8873
 8874    // GGML_ASSERT(ne0 == D);
 8875    // GGML_ASSERT(ne1 == N);
 8876    GGML_ASSERT(P >= 0);
 8877
 8878    GGML_ASSERT(nbq0 == sizeof(float));
 8879    GGML_ASSERT(nbk0 == sizeof(float));
 8880    GGML_ASSERT(nbv0 == sizeof(float));
 8881
 8882    GGML_ASSERT(neq0 == D);
 8883    GGML_ASSERT(nek0 == D);
 8884    GGML_ASSERT(nev1 == D);
 8885    GGML_ASSERT(ned0 == D);
 8886
 8887    GGML_ASSERT(neq1 == N);
 8888    GGML_ASSERT(nek1 == N + P);
 8889    GGML_ASSERT(nev1 == D);
 8890    GGML_ASSERT(ned1 == N);
 8891
 8892    // dst cannot be transposed or permuted
 8893    GGML_ASSERT(nb0 == sizeof(float));
 8894    GGML_ASSERT(nb0 <= nb1);
 8895    GGML_ASSERT(nb1 <= nb2);
 8896    GGML_ASSERT(nb2 <= nb3);
 8897
 8898    if (ith == 0) {
 8899        memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
 8900    }
 8901    ggml_barrier(params->threadpool);
 8902
 8903    const int64_t elem_q = ggml_nelements(q);
 8904    const int64_t elem_k = ggml_nelements(k);
 8905
 8906    ggml_type result_type = dst->type;
 8907    GGML_ASSERT(ggml_blck_size(result_type) == 1);
 8908    const size_t tsize = ggml_type_size(result_type);
 8909
 8910    const size_t offs_q = 0;
 8911    const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
 8912    const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
 8913
 8914    void * grad_q = (char *) dst->data;
 8915    void * grad_k = (char *) dst->data + offs_k;
 8916    void * grad_v = (char *) dst->data + offs_v;
 8917
 8918    const size_t nbgq1 = nb0*neq0;
 8919    const size_t nbgq2 = nb0*neq0*neq1;
 8920    const size_t nbgq3 = nb0*neq0*neq1*neq2;
 8921
 8922    const size_t nbgk1 = nb0*nek0;
 8923    const size_t nbgk2 = nb0*nek0*nek1;
 8924    const size_t nbgk3 = nb0*nek0*nek1*neq2;
 8925
 8926    const size_t nbgv1 = nb0*nev0;
 8927    const size_t nbgv2 = nb0*nev0*nev1;
 8928    const size_t nbgv3 = nb0*nev0*nev1*neq2;
 8929
 8930    // parallelize by k rows using ggml_vec_dot_f32
 8931
 8932    // total rows in k
 8933    const int nr = nek2*nek3;
 8934
 8935    // rows per thread
 8936    const int dr = (nr + nth - 1)/nth;
 8937
 8938    // row range for this thread
 8939    const int ir0 = dr*ith;
 8940    const int ir1 = MIN(ir0 + dr, nr);
 8941
 8942    const float scale = 1.0f/sqrtf(D);
 8943
 8944    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 8945
 8946    // how often k2 (and v2) is repeated in q2
 8947    int nrep = neq2/nek2;
 8948
 8949    for (int ir = ir0; ir < ir1; ++ir) {
 8950        // q indices
 8951        const int ik3 = ir/(nek2);
 8952        const int ik2 = ir - ik3*nek2;
 8953
 8954        const int iq3 = ik3;
 8955        const int id3 = ik3;
 8956        const int iv3 = ik3;
 8957        const int iv2 = ik2;
 8958
 8959        for (int irep = 0; irep < nrep; ++irep) {
 8960            const int iq2 = ik2 + irep*nek2;
 8961            const int id2 = iq2;
 8962
 8963            // (ik2 + irep*nek2) % nek2 == ik2
 8964            for (int iq1 = 0; iq1 < neq1; ++iq1) {
 8965                const int id1 = iq1;
 8966
                // not sure about CACHE_LINE_SIZE_F32:
                // - maybe it should not be multiplied by 2 and excluded from the 1*(..) offset used for SM?
 8969                float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
 8970                float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
 8971
 8972                for (int i = M; i < Mup; ++i) {
 8973                    S[i] = -INFINITY;
 8974                }
 8975
 8976                const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
 8977                for (int64_t ic = 0; ic < masked_begin; ++ic) {
 8978                    // k indices
 8979                    const int ik1 = ic;
 8980
 8981                    // S indices
 8982                    const int i1 = ik1;
 8983
 8984                    ggml_vec_dot_f32(neq0,
 8985                            S + i1, 0,
 8986                            (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
 8987                            (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
 8988                }
 8989
 8990                // scale
 8991                ggml_vec_scale_f32(masked_begin, S, scale);
 8992
 8993                for (int64_t i = masked_begin; i < M; i++) {
 8994                    S[i] = -INFINITY;
 8995                }
 8996
 8997                // softmax
 8998                // exclude known -INF S[..] values from max and loop
  8999                // don't forget to set their SM values to zero
 9000                {
 9001                    float max = -INFINITY;
 9002                    ggml_vec_max_f32(masked_begin, &max, S);
 9003
 9004                    ggml_float sum = 0.0;
 9005                    {
 9006#ifdef GGML_SOFT_MAX_ACCELERATE
 9007                        max = -max;
 9008                        vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
 9009                        vvexpf(SM, SM, &Mup);
 9010                        ggml_vec_sum_f32(Mup, &sum, SM);
 9011#else
 9012                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
 9013#endif
 9014                    }
 9015
 9016                    assert(sum > 0.0);
 9017
 9018                    sum = 1.0/sum;
 9019                    ggml_vec_scale_f32(masked_begin, SM, sum);
 9020
 9021                }
 9022
 9023                // step-by-step explanation
 9024                {
 9025                    // forward-process                    shape      grads from backward process
 9026                    // parallel_for ik2,ik3:
 9027                    //  for irep:
 9028                    //   iq2 = ik2 + irep*nek2
 9029                    //   k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,ik2,ik3]  += grad[kcur]
 9030                    //   q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
 9031                    //   v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iv2,iv3]  += grad[vcur]
 9032                    //   for iq1:
 9033                    //    kcur   = k[:D,:M,ik2,ik3]       [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
 9034                    //    qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
 9035                    //    vcur   = v[:M,:D,iv2,iv3]       [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
 9036                    //    S0     = -Inf                   [D,1,1,1]
 9037                    //   ~S1[i]  = dot(kcur[:D,i], qcur)
 9038                    //    S1     = qcur @ kcur.T          [M,1,1,1]  grad[S1]   = grad[S2] * scale
 9039                    //    S2     = S1 * scale             [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
 9040                    //    S3     = diag_mask_inf(S2, P)   [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
 9041                    //    S4     = softmax(S3)            [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
 9042                    //   ~S5[i]  = dot(vcur[:,i], S4)
 9043                    //    S5     = S4 @ vcur.T            [D,1,1,1]  grad[S5]   = d[:D,id1,id2,id3]
 9044                    //   ~dst[i,iq1,iq2,iq3]  = S5[i]              ^
 9045                    //    dst[:D,iq1,iq2,iq3] = S5                 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
 9046                    // dst                               backward-/ grad[dst]                 = d
 9047                    //
 9048                    // output gradients with their dependencies:
 9049                    //
 9050                    // grad[kcur] = grad[S1].T @ qcur
 9051                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
 9052                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
 9053                    // grad[S4]   = grad[S5] @ vcur
 9054                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
 9055                    // grad[qcur] = grad[S1]   @ kcur
 9056                    // grad[vcur] = grad[S5].T @ S4
 9057                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
 9058                    //
 9059                    // in post-order:
 9060                    //
 9061                    // S1         = qcur @ kcur.T
 9062                    // S2         = S1 * scale
 9063                    // S3         = diag_mask_inf(S2, P)
 9064                    // S4         = softmax(S3)
 9065                    // grad[S4]   = d[:D,id1,id2,id3] @ vcur
 9066                    // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
 9067                    // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
 9068                    // grad[qcur] = grad[S1]   @ kcur
 9069                    // grad[kcur] = grad[S1].T @ qcur
 9070                    // grad[vcur] = d[:D,id1,id2,id3].T @ S4
 9071                    //
 9072                    // using less variables (SM=S4):
 9073                    //
 9074                    // S             = diag_mask_inf(qcur @ kcur.T * scale, P)
 9075                    // SM            = softmax(S)
 9076                    // S             = d[:D,iq1,iq2,iq3] @ vcur
 9077                    // dot_SM_gradSM = dot(SM, S)
 9078                    // S             = SM * (S - dot(SM, S))
 9079                    // S             = diag_mask_zero(S, P) * scale
 9080                    //
 9081                    // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
 9082                    // grad[k][:D,:M,ik2,ik3]  += S.T @ qcur
 9083                    // grad[v][:M,:D,iv2,iv3]  += d[:D,id1,id2,id3].T @ SM
 9084                }
 9085
 9086                // S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
 9087                // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
 9088                // for ic:
 9089                //   S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
 9090                // exclude known future zero S[..] values from operation
 9091                ggml_vec_set_f32(masked_begin, S, 0);
 9092                for (int64_t ic = 0; ic < D; ++ic) {
 9093                    ggml_vec_mad_f32(masked_begin,
 9094                            S,
 9095                             (float *) ((char *) v->data + (          ic*nbv1  + iv2*nbv2 + iv3*nbv3)),
 9096                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
 9097                }
 9098
 9099                // S = SM * (S - dot(SM, S))
 9100                float dot_SM_gradSM = 0;
 9101                ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
 9102                ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
 9103                ggml_vec_mul_f32 (masked_begin, S, S, SM);
 9104
 9105                // S = diag_mask_zero(S, P) * scale
 9106                // already done by above ggml_vec_set_f32
 9107
 9108                // exclude known zero S[..] values from operation
 9109                ggml_vec_scale_f32(masked_begin, S, scale);
 9110
 9111                // S    shape [M,1]
 9112                // SM   shape [M,1]
 9113                // kcur shape [D,M]
 9114                // qcur shape [D,1]
 9115                // vcur shape [M,D]
 9116
 9117                // grad[q][:D,iq1,iq2,iq3] += S @ kcur
 9118                // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
 9119                // for ic:
 9120                //  grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
 9121                // exclude known zero S[..] values from loop
 9122                for (int64_t ic = 0; ic < masked_begin; ++ic) {
 9123                    ggml_vec_mad_f32(D,
 9124                            (float *) ((char *) grad_q  + (iq1*nbgq1 + iq2*nbgq2  + iq3*nbgq3)),
 9125                            (float *) ((char *) k->data + (ic*nbk1   + ik2*nbk2   + ik3*nbk3)),
 9126                            S[ic]);
 9127                }
 9128
  9129                // grad[k][:D,:M,ik2,ik3] += S.T @ qcur
  9130                // for ic:
  9131                //  grad[k][:D,ic,ik2,ik3] += S.T[0,ic] * qcur[:D,0]
  9132                //  grad[k][:D,ic,ik2,ik3] += S[ic]     * qcur[:D,0]
 9133                // exclude known zero S[..] values from loop
 9134                for (int64_t ic = 0; ic < masked_begin; ++ic) {
 9135                    ggml_vec_mad_f32(D,
 9136                            (float *) ((char *) grad_k  + (ic*nbgk1  + ik2*nbgk2  + ik3*nbgk3)),
 9137                            (float *) ((char *) q->data + (iq1*nbq1  + iq2*nbq2   + iq3*nbq3)),
 9138                            S[ic]);
 9139                }
 9140
 9141                // grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T       @ SM
 9142                // for ic:
 9143                //  grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
 9144                //  grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3]         * SM[:M]
 9145                // exclude known zero SM[..] values from mad
 9146                for (int64_t ic = 0; ic < D; ++ic) {
 9147                    ggml_vec_mad_f32(masked_begin,
 9148                            (float *) ((char *) grad_v   + (          ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
 9149                            SM,
 9150                            *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2  + id3*nbd3)));
 9151                }
 9152            }
 9153        }
 9154    }
 9155}
 9156
 9157void ggml_compute_forward_flash_attn_back(
 9158        const ggml_compute_params * params,
 9159        const bool masked,
 9160        ggml_tensor * dst) {
 9161
 9162    const ggml_tensor * q = dst->src[0];
 9163
 9164    switch (q->type) {
 9165        case GGML_TYPE_F32:
 9166            {
 9167                ggml_compute_forward_flash_attn_back_f32(params, masked, dst);
 9168            } break;
 9169        default:
 9170            {
 9171                GGML_ABORT("fatal error");
 9172            }
 9173    }
 9174}
 9175
 9176// ggml_compute_forward_ssm_conv
 9177
 9178static void ggml_compute_forward_ssm_conv_f32(
 9179        const ggml_compute_params * params,
 9180        ggml_tensor * dst) {
 9181    const ggml_tensor * src0 = dst->src[0]; // conv_x
 9182    const ggml_tensor * src1 = dst->src[1]; // conv1d.weight
 9183
 9184    const int ith = params->ith;
 9185    const int nth = params->nth;
 9186
 9187    const int nc  = src1->ne[0]; // d_conv
 9188    const int ncs = src0->ne[0]; // d_conv - 1 + n_t
 9189    const int nr  = src0->ne[1]; // d_inner
 9190    const int n_t =  dst->ne[1]; // tokens per sequence
 9191    const int n_s =  dst->ne[2]; // number of sequences in the batch
 9192
 9193    GGML_ASSERT( dst->ne[0] == nr);
 9194    GGML_ASSERT(src0->nb[0] == sizeof(float));
 9195    GGML_ASSERT(src1->nb[0] == sizeof(float));
 9196    GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float));
 9197
 9198    // rows per thread
 9199    const int dr = (nr + nth - 1)/nth;
 9200
 9201    // row range for this thread
 9202    const int ir0 = dr*ith;
 9203    const int ir1 = MIN(ir0 + dr, nr);
 9204    const int ir  = ir1 - ir0;
 9205
 9206    for (int i3 = 0; i3 < n_s; ++i3) {
 9207        for (int i2 = 0; i2 < n_t; ++i2) {
 9208            // {d_conv - 1 + n_t, d_inner, n_seqs}
 9209            // sliding window
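                  // each token i2 shifts the window start by one element along ne[0],
                  // so the dot product below sees the previous d_conv inputs (causal conv)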
 9210            const float * s = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i2*(src0->nb[0]) + i3*(src0->nb[2])); // {d_conv, d_inner, n_s}
 9211            const float * c = (const float *) ((const char *) src1->data + ir0*(src1->nb[1])); // {d_conv, d_inner}
 9212            float * x = (float *) ((char *) dst->data + ir0*(dst->nb[0]) + i2*(dst->nb[1]) + i3*(dst->nb[2])); // {d_inner, n_t, n_s}
 9213
 9214            // TODO: transpose the output for smaller strides for big batches?
 9215            // d_inner
 9216            for (int i1 = 0; i1 < ir; ++i1) {
 9217                // rowwise dot product
 9218                // NOTE: not using ggml_vec_dot_f32, because its sum is in double precision
 9219                float sumf = 0.0f;
 9220
 9221                // d_conv
 9222                for (int i0 = 0; i0 < nc; ++i0) {
 9223                    sumf += s[i0 + i1*ncs] * c[i0 + i1*nc];
 9224                }
 9225                x[i1] = sumf;
 9226            }
 9227        }
 9228    }
 9229}
 9230
 9231void ggml_compute_forward_ssm_conv(
 9232        const ggml_compute_params * params,
 9233        ggml_tensor * dst) {
 9234    switch (dst->src[0]->type) {
 9235        case GGML_TYPE_F32:
 9236            {
 9237                ggml_compute_forward_ssm_conv_f32(params, dst);
 9238            } break;
 9239        default:
 9240            {
 9241                GGML_ABORT("fatal error");
 9242            }
 9243    }
 9244}
 9245
 9246// ggml_compute_forward_ssm_scan
 9247
 9248static void ggml_compute_forward_ssm_scan_f32(
 9249        const ggml_compute_params * params,
 9250        ggml_tensor * dst) {
 9251    const ggml_tensor * src0 = dst->src[0]; // s  {d_state, dim, n_head, n_seqs+}
 9252    const ggml_tensor * src1 = dst->src[1]; // x  {dim, n_head, n_seq_tokens, n_seqs}
 9253    const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs}
 9254    const ggml_tensor * src3 = dst->src[3]; // A  {d_state, n_head} or {1, n_head}
 9255    const ggml_tensor * src4 = dst->src[4]; // B  {d_state, n_group, n_seq_tokens, n_seqs}
 9256    const ggml_tensor * src5 = dst->src[5]; // C  {d_state, n_group, n_seq_tokens, n_seqs}
 9257    const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs}
 9258
 9259    const int ith = params->ith;
 9260    const int nth = params->nth;
 9261
 9262    const int64_t nc = src0->ne[0]; // d_state
 9263    const int64_t nr = src0->ne[1]; // dim
 9264    const int64_t nh = src1->ne[1]; // n_head
  9265    const int64_t ng = src4->ne[1]; // n_group
 9266    const int64_t nt = src1->ne[2]; // number of tokens per sequence
 9267    const int64_t ns = src1->ne[3]; // number of sequences in the batch
 9268
 9269    // can't use ggml_nbytes because src1 is not necessarily contiguous
 9270    const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1);
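          // dst layout: [ y (same number of elements as src1) | updated states {d_state, dim, n_head, n_seqs} ]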
 9271
 9272    GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst));
 9273    GGML_ASSERT(src0->nb[0] == sizeof(float));
 9274    GGML_ASSERT(src1->nb[0] == sizeof(float));
 9275    GGML_ASSERT(src2->nb[0] == sizeof(float));
 9276    GGML_ASSERT(src3->nb[0] == sizeof(float));
 9277    GGML_ASSERT(src4->nb[0] == sizeof(float));
 9278    GGML_ASSERT(src5->nb[0] == sizeof(float));
 9279    GGML_ASSERT(src6->nb[0] == sizeof(int32_t));
 9280    GGML_ASSERT(nh % ng == 0);
 9281
 9282    // heads per thread
 9283    const int dh = (nh + nth - 1)/nth;
 9284
 9285    // head range for this thread
 9286    const int ih0 = dh*ith;
 9287    const int ih1 = MIN(ih0 + dh, nh);
 9288
 9289    const int32_t * ids = (const int32_t *) src6->data;
 9290
 9291    for (int i3 = 0; i3 < ns; ++i3) {
 9292        const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns}
 9293              float * s  = (      float *) ((      char *) dst->data  + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns}
 9294
 9295        for (int i2 = 0; i2 < nt; ++i2) {
 9296            const float * x  = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns}
 9297            const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns}
 9298            const float * A  = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh}
 9299            const float * B  = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns}
 9300            const float * C  = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns}
 9301                  float * y  = (      float *) ((      char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns}
 9302
 9303            if (src3->ne[0] == 1) {
 9304                // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop
 9305
 9306                // n_head
 9307                for (int h = ih0; h < ih1; ++h) {
 9308                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
 9309                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
 9310                    const float dA = expf(dt_soft_plus * A[h]);
 9311                    const int g = h / (nh / ng); // repeat_interleave
 9312
 9313                    // dim
 9314                    for (int i1 = 0; i1 < nr; ++i1) {
 9315                        const int ii = i1 + h*nr;
 9316                        const float x_dt = x[ii] * dt_soft_plus;
 9317                        float sumf = 0.0f;
 9318#if defined(GGML_SIMD)
 9319    #if defined(__ARM_FEATURE_SVE)
 9320                        const int ggml_f32_epr = svcntw();
 9321                        const int ggml_f32_step = 1 * ggml_f32_epr;
 9322
 9323                        const int np = (nc & ~(ggml_f32_step - 1));
 9324
 9325                        GGML_F32_VEC sum = GGML_F32_VEC_ZERO;
 9326
 9327                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
 9328                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
 9329
 9330                        for (int i = 0; i < np; i += ggml_f32_step) {
 9331                            // TODO: maybe unroll more?
 9332                            for (int j = 0; j < 1; j++) {
 9333                                GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc);
 9334                                GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + g*nc);
 9335                                GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + g*nc);
 9336
 9337                                t0 = GGML_F32_VEC_MUL(t0, adA);
 9338                                t1 = GGML_F32_VEC_MUL(t1, axdt);
 9339
 9340                                t0 = GGML_F32_VEC_ADD(t0, t1);
 9341
 9342                                sum = GGML_F32_VEC_FMA(sum, t0, t2);
 9343
 9344                                GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0);
 9345                            }
 9346                        }
 9347
 9348                        sumf = GGML_F32xt_REDUCE_ONE(sum);
 9349    #elif defined(__riscv_v_intrinsic)
 9350                        // todo: RVV implementation
 9351                        const int np = 0;
 9352    #else
 9353                        const int np = (nc & ~(GGML_F32_STEP - 1));
 9354
 9355                        GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
 9356
 9357                        GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA);
 9358                        GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt);
 9359
 9360                        GGML_F32_VEC ax[GGML_F32_ARR];
 9361                        GGML_F32_VEC ay[GGML_F32_ARR];
 9362                        GGML_F32_VEC az[GGML_F32_ARR];
 9363
 9364                        for (int i = 0; i < np; i += GGML_F32_STEP) {
 9365                            for (int j = 0; j < GGML_F32_ARR; j++) {
 9366                                ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc);
 9367                                ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + g*nc);
 9368                                az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + g*nc);
 9369
 9370                                ax[j] = GGML_F32_VEC_MUL(ax[j], adA);
 9371                                ay[j] = GGML_F32_VEC_MUL(ay[j], axdt);
 9372
 9373                                ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]);
 9374
 9375                                sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]);
 9376
 9377                                GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]);
 9378                            }
 9379                        }
 9380
 9381                        // reduce sum0..sum3 to sum0
 9382                        GGML_F32_VEC_REDUCE(sumf, sum);
 9383    #endif
 9384#else
 9385                        const int np = 0;
 9386#endif
 9387                        // d_state
 9388                        for (int i0 = np; i0 < nc; ++i0) {
 9389                            const int i = i0 + ii*nc;
 9390                            const int ig = i0 + g*nc;
 9391                            // state = prev_state * dA + dB * x
 9392                            const float state = (s0[i] * dA) + (B[ig] * x_dt);
 9393                            // y = rowwise_dotprod(state, C)
 9394                            sumf += state * C[ig];
 9395                            s[i] = state;
 9396                        }
 9397                        y[ii] = sumf;
 9398                    }
 9399                }
 9400            } else {
 9401                // Mamba-1 has an element-wise decay factor for the states
 9402
 9403                // n_head
 9404                for (int h = ih0; h < ih1; ++h) {
 9405                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
 9406                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
 9407                    const int g = h / (nh / ng); // repeat_interleave
 9408
 9409                    // dim
 9410                    for (int i1 = 0; i1 < nr; ++i1) {
 9411                        const int ii = i1 + h*nr;
 9412                        const float x_dt = x[ii] * dt_soft_plus;
 9413#if defined(__ARM_FEATURE_SVE)
 9414                        svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt);
 9415                        svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus);
 9416                        svfloat32_t r1_vector = GGML_F32_VEC_ZERO;
 9417
 9418                        // d_state
 9419                        // TODO: what happens when (d_state % svcntw()) != 0?
 9420                        for (int64_t k = 0; k < nc; k += svcntw()) {
 9421                            svfloat32_t vA = GGML_F32_VEC_LOAD(&A[h*nc + k]);
 9422                            svfloat32_t vB = GGML_F32_VEC_LOAD(&B[k + g*nc]);
 9423                            svfloat32_t vC = GGML_F32_VEC_LOAD(&C[k + g*nc]);
 9424                            svfloat32_t vs0 = GGML_F32_VEC_LOAD(&s0[ii*nc + k]);
 9425
 9426                            svfloat32_t t1 = GGML_F32_VEC_MUL(vdt_soft_plus, vA);
 9427                            t1 = exp_ps_sve(svptrue_b32(), t1);
 9428                            svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB);
 9429
 9430                            vs0 = GGML_F32_VEC_FMA(t2, vs0, t1);
 9431                            r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector);
 9432
 9433                            GGML_F32_VEC_STORE(&s[ii*nc + k], vs0);
 9434                        }
 9435                        y[ii] = GGML_F32xt_REDUCE_ONE(r1_vector);
 9436#else
 9437                        float sumf = 0.0f;
 9438                        // NOTE: can't really use GGML_SIMD here because d_state is usually 16
 9439                        //       and also because expf is used within the loop.
 9440                        // d_state
 9441                        for (int i0 = 0; i0 < nc; ++i0) {
 9442                            const int i = i0 + ii*nc;
 9443                            const int ig = i0 + g*nc;
 9444                            // state = prev_state * dA + dB * x
 9445                            const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt);
 9446                            // y = rowwise_dotprod(state, C)
 9447                            sumf += state * C[ig];
 9448                            s[i] = state;
 9449                        }
 9450                        y[ii] = sumf;
 9451#endif
 9452                    }
 9453                }
 9454            }
 9455            // use the output as the source when it's not the first token-wise iteration
 9456            s0 = s;
 9457        }
 9458    }
 9459}
 9460
 9461void ggml_compute_forward_ssm_scan(
 9462        const ggml_compute_params * params,
 9463        ggml_tensor * dst) {
 9464    switch (dst->src[0]->type) {
 9465        case GGML_TYPE_F32:
 9466            {
 9467                ggml_compute_forward_ssm_scan_f32(params, dst);
 9468            } break;
 9469        default:
 9470            {
 9471                GGML_ABORT("fatal error");
 9472            }
 9473    }
 9474}
 9475
 9476// ggml_compute_forward_win_part
 9477
 9478static void ggml_compute_forward_win_part_f32(
 9479        const ggml_compute_params * params,
 9480        ggml_tensor * dst) {
 9481    GGML_UNUSED(params);
 9482
 9483    const ggml_tensor * src0 = dst->src[0];
 9484
 9485    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 9486    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 9487
 9488    const int32_t nep0 = ((const int32_t *)(dst->op_params))[0];
 9489    const int32_t nep1 = ((const int32_t *)(dst->op_params))[1];
 9490    const int32_t w    = ((const int32_t *)(dst->op_params))[2];
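          // nep0/nep1: number of window partitions along each spatial dim, w: window size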
 9491
 9492    assert(ne00 == ne0);
 9493    assert(ne3  == nep0*nep1);
 9494
 9495    // TODO: optimize / multi-thread
 9496    for (int py = 0; py < nep1; ++py) {
 9497        for (int px = 0; px < nep0; ++px) {
 9498            const int64_t i3 = py*nep0 + px;
 9499            for (int64_t i2 = 0; i2 < ne2; ++i2) {
 9500                for (int64_t i1 = 0; i1 < ne1; ++i1) {
 9501                    for (int64_t i0 = 0; i0 < ne0; ++i0) {
 9502                        const int64_t i02 = py*w + i2;
 9503                        const int64_t i01 = px*w + i1;
 9504                        const int64_t i00 = i0;
 9505
 9506                        const int64_t i = i3*ne2*ne1*ne0 + i2*ne1*ne0    + i1*ne0   + i0;
 9507                        const int64_t j =                  i02*ne01*ne00 + i01*ne00 + i00;
 9508
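                              // zero-fill positions that fall outside the input (padding)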
 9509                        if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
 9510                            ((float *) dst->data)[i] = 0.0f;
 9511                        } else {
 9512                            ((float *) dst->data)[i] = ((float *) src0->data)[j];
 9513                        }
 9514                    }
 9515                }
 9516            }
 9517        }
 9518    }
 9519}
 9520
 9521void ggml_compute_forward_win_part(
 9522        const ggml_compute_params * params,
 9523        ggml_tensor * dst) {
 9524
 9525    const ggml_tensor * src0 = dst->src[0];
 9526
 9527    switch (src0->type) {
 9528        case GGML_TYPE_F32:
 9529            {
 9530                ggml_compute_forward_win_part_f32(params, dst);
 9531            } break;
 9532        default:
 9533            {
 9534                GGML_ABORT("fatal error");
 9535            }
 9536    }
 9537}
 9538
 9539// ggml_compute_forward_win_unpart
 9540
 9541static void ggml_compute_forward_win_unpart_f32(
 9542        const ggml_compute_params * params,
 9543        ggml_tensor * dst) {
 9544    GGML_UNUSED(params);
 9545
 9546    const ggml_tensor * src0 = dst->src[0];
 9547
 9548    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
 9549    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne)
 9550
 9551    const int32_t w = ((const int32_t *)(dst->op_params))[0];
 9552
 9553    // padding
 9554    const int px = (w - ne1%w)%w;
 9555    //const int py = (w - ne2%w)%w;
 9556
 9557    const int npx = (px + ne1)/w;
 9558    //const int npy = (py + ne2)/w;
 9559
 9560    assert(ne0 == ne00);
 9561
 9562    // TODO: optimize / multi-thread
 9563    for (int64_t i2 = 0; i2 < ne2; ++i2) {
 9564        for (int64_t i1 = 0; i1 < ne1; ++i1) {
 9565            for (int64_t i0 = 0; i0 < ne0; ++i0) {
 9566                const int ip2 = i2/w;
 9567                const int ip1 = i1/w;
 9568
 9569                const int64_t i02 = i2%w;
 9570                const int64_t i01 = i1%w;
 9571                const int64_t i00 = i0;
 9572
 9573                const int64_t i = (ip2*npx + ip1)*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00 + i00;
 9574                const int64_t j =                                  i2*ne1*ne0    + i1*ne0   + i0;
 9575
 9576                ((float *) dst->data)[j] = ((float *) src0->data)[i];
 9577            }
 9578        }
 9579    }
 9580}
 9581
 9582void ggml_compute_forward_win_unpart(
 9583        const ggml_compute_params * params,
 9584        ggml_tensor * dst) {
 9585
 9586    const ggml_tensor * src0 = dst->src[0];
 9587
 9588    switch (src0->type) {
 9589        case GGML_TYPE_F32:
 9590            {
 9591                ggml_compute_forward_win_unpart_f32(params, dst);
 9592            } break;
 9593        default:
 9594            {
 9595                GGML_ABORT("fatal error");
 9596            }
 9597    }
 9598}
 9599
  9600// ggml_compute_forward_unary
 9601
 9602void ggml_compute_forward_unary(
 9603        const ggml_compute_params * params,
 9604        ggml_tensor * dst) {
 9605
 9606    const ggml_unary_op op = ggml_get_unary_op(dst);
 9607
 9608    switch (op) {
 9609        case GGML_UNARY_OP_ABS:
 9610            {
 9611                ggml_compute_forward_abs(params, dst);
 9612            } break;
 9613        case GGML_UNARY_OP_SGN:
 9614            {
 9615                ggml_compute_forward_sgn(params, dst);
 9616            } break;
 9617        case GGML_UNARY_OP_NEG:
 9618            {
 9619                ggml_compute_forward_neg(params, dst);
 9620            } break;
 9621        case GGML_UNARY_OP_STEP:
 9622            {
 9623                ggml_compute_forward_step(params, dst);
 9624            } break;
 9625        case GGML_UNARY_OP_TANH:
 9626            {
 9627                ggml_compute_forward_tanh(params, dst);
 9628            } break;
 9629        case GGML_UNARY_OP_ELU:
 9630            {
 9631                ggml_compute_forward_elu(params, dst);
 9632            } break;
 9633        case GGML_UNARY_OP_RELU:
 9634            {
 9635                ggml_compute_forward_relu(params, dst);
 9636            } break;
 9637        case GGML_UNARY_OP_SIGMOID:
 9638            {
 9639                ggml_compute_forward_sigmoid(params, dst);
 9640            } break;
 9641        case GGML_UNARY_OP_GELU:
 9642            {
 9643                ggml_compute_forward_gelu(params, dst);
 9644            } break;
 9645        case GGML_UNARY_OP_GELU_ERF:
 9646            {
 9647                ggml_compute_forward_gelu_erf(params, dst);
 9648            } break;
 9649        case GGML_UNARY_OP_GELU_QUICK:
 9650            {
 9651                ggml_compute_forward_gelu_quick(params, dst);
 9652            } break;
 9653        case GGML_UNARY_OP_SILU:
 9654            {
 9655                ggml_compute_forward_silu(params, dst);
 9656            } break;
 9657        case GGML_UNARY_OP_HARDSWISH:
 9658            {
 9659                ggml_compute_forward_hardswish(params, dst);
 9660            } break;
 9661        case GGML_UNARY_OP_HARDSIGMOID:
 9662            {
 9663                ggml_compute_forward_hardsigmoid(params, dst);
 9664            } break;
 9665        case GGML_UNARY_OP_EXP:
 9666            {
 9667                ggml_compute_forward_exp(params, dst);
 9668            } break;
 9669        case GGML_UNARY_OP_FLOOR:
 9670            {
 9671                ggml_compute_forward_floor(params, dst);
 9672            } break;
 9673        case GGML_UNARY_OP_CEIL:
 9674            {
 9675                ggml_compute_forward_ceil(params, dst);
 9676            } break;
 9677        case GGML_UNARY_OP_ROUND:
 9678            {
 9679                ggml_compute_forward_round(params, dst);
 9680            } break;
 9681        case GGML_UNARY_OP_TRUNC:
 9682            {
 9683                ggml_compute_forward_trunc(params, dst);
 9684            } break;
 9685        case GGML_UNARY_OP_XIELU:
 9686            {
 9687                ggml_compute_forward_xielu(params, dst);
 9688            } break;
 9689        case GGML_UNARY_OP_EXPM1:
 9690            {
 9691                ggml_compute_forward_expm1(params, dst);
 9692            } break;
 9693        case GGML_UNARY_OP_SOFTPLUS:
 9694            {
 9695                ggml_compute_forward_softplus(params, dst);
 9696            } break;
 9697        default:
 9698            {
 9699                GGML_ABORT("fatal error");
 9700            }
 9701    }
 9702}
 9703
  9704// ggml_compute_forward_glu
 9705
 9706void ggml_compute_forward_glu(
 9707        const ggml_compute_params * params,
 9708        ggml_tensor * dst) {
 9709
 9710    const ggml_glu_op op = ggml_get_glu_op(dst);
 9711
 9712    switch (op) {
 9713        case GGML_GLU_OP_REGLU:
 9714            {
 9715                ggml_compute_forward_reglu(params, dst);
 9716            } break;
 9717        case GGML_GLU_OP_GEGLU:
 9718            {
 9719                ggml_compute_forward_geglu(params, dst);
 9720            } break;
 9721        case GGML_GLU_OP_SWIGLU:
 9722            {
 9723                ggml_compute_forward_swiglu(params, dst);
 9724            } break;
 9725        case GGML_GLU_OP_SWIGLU_OAI:
 9726            {
 9727                ggml_compute_forward_swiglu_oai(params, dst);
 9728            } break;
 9729        case GGML_GLU_OP_GEGLU_ERF:
 9730            {
 9731                ggml_compute_forward_geglu_erf(params, dst);
 9732            } break;
 9733        case GGML_GLU_OP_GEGLU_QUICK:
 9734            {
 9735                ggml_compute_forward_geglu_quick(params, dst);
 9736            } break;
 9737        default:
 9738            {
 9739                GGML_ABORT("fatal error");
 9740            }
 9741    }
 9742}
 9743
 9744// ggml_compute_forward_get_rel_pos
 9745
 9746static void ggml_compute_forward_get_rel_pos_f16(
 9747        const ggml_compute_params * params,
 9748        ggml_tensor * dst) {
 9749    GGML_UNUSED(params);
 9750
 9751    const ggml_tensor * src0 = dst->src[0];
 9752
 9753    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L292-L322
 9754
 9755    GGML_TENSOR_UNARY_OP_LOCALS
 9756
 9757    const int64_t w = ne1;
 9758
 9759    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
 9760    ggml_fp16_t * dst_data  = (ggml_fp16_t *) dst->data;
 9761
 9762    for (int64_t i2 = 0; i2 < ne2; ++i2) {
 9763        for (int64_t i1 = 0; i1 < ne1; ++i1) {
 9764            const int64_t pos = (w - i1 - 1) + i2;
 9765            for (int64_t i0 = 0; i0 < ne0; ++i0) {
 9766                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
 9767            }
 9768        }
 9769    }
 9770}
 9771
 9772void ggml_compute_forward_get_rel_pos(
 9773        const ggml_compute_params * params,
 9774        ggml_tensor * dst) {
 9775
 9776    const ggml_tensor * src0 = dst->src[0];
 9777
 9778    switch (src0->type) {
 9779        case GGML_TYPE_F16:
 9780        case GGML_TYPE_BF16:
 9781            {
 9782                ggml_compute_forward_get_rel_pos_f16(params, dst);
 9783            } break;
 9784        default:
 9785            {
 9786                GGML_ABORT("fatal error");
 9787            }
 9788    }
 9789}
 9790
 9791// ggml_compute_forward_add_rel_pos
 9792
 9793static void ggml_compute_forward_add_rel_pos_f32(
 9794        const ggml_compute_params * params,
 9795        ggml_tensor * dst) {
 9796
 9797    const ggml_tensor * src0 = dst->src[0];
 9798    const ggml_tensor * src1 = dst->src[1];
 9799    const ggml_tensor * src2 = dst->src[2];
 9800
 9801    const bool inplace = (bool) ((int32_t *) dst->op_params)[0];
 9802    if (!inplace) {
 9803        if (params->ith == 0) {
 9804            memcpy((char *) dst->data, (char *) src0->data, ggml_nbytes(dst));
 9805        }
 9806        ggml_barrier(params->threadpool);
 9807    }
 9808    // ref: https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py#L357-L359
 9809
 9810    float * src1_data = (float *) src1->data;
 9811    float * src2_data = (float *) src2->data;
 9812    float * dst_data  = (float *) dst->data;
 9813
 9814    const int64_t ne10 = src1->ne[0];
 9815    const int64_t ne11 = src1->ne[1];
 9816    const int64_t ne12 = src1->ne[2];
 9817    const int64_t ne13 = src1->ne[3];
 9818
 9819    const int ith = params->ith;
 9820    const int nth = params->nth;
 9821
 9822    // total patches in dst
 9823    const int np = ne13;
 9824
 9825    // patches per thread
 9826    const int dp = (np + nth - 1)/nth;
 9827
 9828    // patch range for this thread
 9829    const int ip0 = dp*ith;
 9830    const int ip1 = MIN(ip0 + dp, np);
 9831
 9832    for (int64_t i13 = ip0; i13 < ip1; ++i13) {
 9833        for (int64_t i12 = 0; i12 < ne12; ++i12) {
 9834            for (int64_t i11 = 0; i11 < ne11; ++i11) {
 9835                const int64_t jp1 = i13*ne12*ne11*ne10 + i12*ne11*ne10 + i11*ne10;
 9836                for (int64_t i10 = 0; i10 < ne10; ++i10) {
 9837                    const int64_t jp0  = jp1 + i10;
 9838                    const float src1_e = src1_data[jp0];
 9839                    const float src2_e = src2_data[jp0];
 9840
 9841                    const int64_t jdh = jp0 * ne10;
 9842                    const int64_t jdw = jdh - (ne10 - 1) * i10;
 9843
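                          // src2_e is broadcast across a contiguous run of ne10 dst elements,
                          // src1_e across ne10 dst elements with stride ne10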
 9844                    for (int64_t j = 0; j < ne10; ++j) {
 9845                        dst_data[jdh + j     ] += src2_e;
 9846                        dst_data[jdw + j*ne10] += src1_e;
 9847                    }
 9848                }
 9849            }
 9850        }
 9851    }
 9852}
 9853
 9854void ggml_compute_forward_add_rel_pos(
 9855        const ggml_compute_params * params,
 9856        ggml_tensor * dst) {
 9857
 9858    const ggml_tensor * src0 = dst->src[0];
 9859
 9860    switch (src0->type) {
 9861        case GGML_TYPE_F32:
 9862            {
 9863                ggml_compute_forward_add_rel_pos_f32(params, dst);
 9864            } break;
 9865        default:
 9866            {
 9867                GGML_ABORT("fatal error");
 9868            }
 9869    }
 9870}
 9871
 9872// ggml_compute_forward_rwkv_wkv6
 9873
 9874static void ggml_compute_forward_rwkv_wkv6_f32(
 9875        const ggml_compute_params * params,
 9876        ggml_tensor * dst) {
 9877    const int64_t T = dst->src[1]->ne[2];
 9878    const int64_t C = dst->ne[0];
 9879    const int64_t HEADS = dst->src[1]->ne[1];
 9880    const int64_t n_seqs = dst->src[5]->ne[1];
 9881    const int64_t head_size = C / HEADS;
 9882
 9883    float * dst_data = (float *) dst->data;
 9884    float * state = ((float *) dst->data) + C * T;
 9885
 9886    const int ith = params->ith;
 9887    const int nth = params->nth;
 9888
 9889    if (ith >= HEADS) {
 9890        return;
 9891    }
 9892
 9893    const int h_start = (HEADS * ith) / nth;
 9894    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
 9895                (HEADS * (ith + 1)) / nth : HEADS;
 9896
 9897    float * k =          (float *) dst->src[0]->data;
 9898    float * v =          (float *) dst->src[1]->data;
 9899    float * r =          (float *) dst->src[2]->data;
 9900    float * time_faaaa = (float *) dst->src[3]->data;
 9901    float * time_decay = (float *) dst->src[4]->data;
 9902
  9903    size_t t_stride = HEADS * head_size; // same as C
 9904
 9905    size_t h_stride = C / HEADS;
 9906    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
 9907    size_t h_stride_2d = head_size * head_size;
 9908
 9909    if (ith == 0) {
 9910        memset(dst_data, 0, T * C * sizeof(float));
 9911    }
 9912    ggml_barrier(params->threadpool);
 9913
 9914
 9915    #if defined(__AVX__) && !defined(__AVX512F__)
 9916        #define GGML_F32X GGML_F32x8
 9917        #define GGML_F32X_SET1 GGML_F32x8_SET1
 9918        #define GGML_F32X_LOAD GGML_F32x8_LOAD
 9919        #define GGML_F32X_STORE GGML_F32x8_STORE
 9920        #define GGML_F32X_MUL GGML_F32x8_MUL
 9921        #define GGML_F32X_FMA GGML_F32x8_FMA
 9922        #define WKV_VECTOR_SIZE 8
 9923    #elif defined(__AVX512F__)
 9924        #define GGML_F32X GGML_F32x16
 9925        #define GGML_F32X_SET1 GGML_F32x16_SET1
 9926        #define GGML_F32X_LOAD GGML_F32x16_LOAD
 9927        #define GGML_F32X_STORE GGML_F32x16_STORE
 9928        #define GGML_F32X_MUL GGML_F32x16_MUL
 9929        #define GGML_F32X_FMA GGML_F32x16_FMA
 9930        #define WKV_VECTOR_SIZE 16
 9931    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
 9932        #define GGML_F32X GGML_F32xt
 9933        #define GGML_F32X_SET1 GGML_F32xt_SET1
 9934        #define GGML_F32X_LOAD GGML_F32xt_LOAD
 9935        #define GGML_F32X_STORE GGML_F32xt_STORE
 9936        #define GGML_F32X_MUL GGML_F32xt_MUL
 9937        #define GGML_F32X_FMA GGML_F32xt_FMA
 9938        #define WKV_VECTOR_SIZE 8
 9939    #elif defined(__ARM_NEON) && defined(__aarch64__)
 9940        #define GGML_F32X GGML_F32x4
 9941        #define GGML_F32X_SET1 GGML_F32x4_SET1
 9942        #define GGML_F32X_LOAD GGML_F32x4_LOAD
 9943        #define GGML_F32X_STORE GGML_F32x4_STORE
 9944        #define GGML_F32X_MUL GGML_F32x4_MUL
 9945        #define GGML_F32X_FMA GGML_F32x4_FMA
 9946        #define WKV_VECTOR_SIZE 4
 9947    #endif
 9948
 9949    #ifdef WKV_VECTOR_SIZE
 9950        int wkv_vector_size;
 9951        #if defined(__ARM_FEATURE_SVE)
 9952            wkv_vector_size = svcntw();
 9953        #else
 9954            wkv_vector_size = WKV_VECTOR_SIZE;
 9955        #endif
 9956        const int64_t vec_count = head_size / wkv_vector_size;
 9957
 9958        for (int64_t t = 0; t < T; t++) {
 9959            size_t t_offset = t * t_stride;
 9960            size_t state_offset = head_size * C * (t / (T / n_seqs));
 9961            float * state_cur = state + state_offset;
 9962            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
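                  // the first token of each sequence reads the incoming state from src[5];
                  // later tokens reuse the state updated in place in dst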
 9963
 9964            for (int64_t h = h_start; h < h_end; h++) {
 9965                size_t h_offset = h * h_stride;
 9966                size_t t_h_offset = t_offset + h_offset;
 9967                size_t h_2d_offset = h * h_stride_2d;
 9968
 9969                for (int64_t i = 0; i < head_size; i++) {
 9970                    size_t t_h_i_offset = t_h_offset + i;
 9971                    size_t h_i_offset = h_offset + i;
 9972                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
 9973
 9974                    float k_val = k[t_h_i_offset];
 9975                    float r_val = r[t_h_i_offset];
 9976                    float time_faaaa_val = time_faaaa[h_i_offset];
 9977                    float time_decay_val = time_decay[t_h_i_offset];
 9978
 9979                    // Broadcast scalar values to vectors
 9980                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
 9981                    GGML_F32X r_vec = GGML_F32X_SET1(r_val);
 9982                    GGML_F32X time_faaaa_vec = GGML_F32X_SET1(time_faaaa_val);
 9983                    GGML_F32X time_decay_vec = GGML_F32X_SET1(time_decay_val);
 9984
 9985                    for (int64_t j = 0; j < vec_count; j++) {
 9986                        size_t base_j = j * wkv_vector_size;
 9987                        size_t t_h_j_offset = t_h_offset + base_j;
 9988                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
 9989
  9990                        // Load wkv_vector_size elements at once
 9991                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
 9992                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
 9993                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
 9994
 9995                        // Compute kv = v * k
 9996                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
 9997
 9998                        // Compute temp = kv * time_faaaa + prev_state
 9999                        GGML_F32X temp_vec = GGML_F32X_FMA(prev_state_vec, kv_vec, time_faaaa_vec);
10000
10001                        // Update dst: dst += temp * r
10002                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, r_vec);
10003                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10004
10005                        // Update state: state = prev_state * time_decay + kv
10006                        GGML_F32X new_state_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, time_decay_vec);
10007                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], new_state_vec);
10008                    }
10009
 10010                    // Handle remaining elements (normally none: head_size is a multiple of the vector width)
10011                    for (int64_t j = vec_count * wkv_vector_size; j < head_size; j++) {
10012                        size_t t_h_j_offset = t_h_offset + j;
10013                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10014                        float v_val = v[t_h_j_offset];
10015                        float kv_val = v_val * k_val;
10016                        float prev_state_val = state_prev[h_2d_i_j_offset];
10017                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
10018                        dst_data[t_h_j_offset] += temp_val * r_val;
10019                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10020                    }
10021                }
10022            }
10023        }
10024
10025    #else
10026        // basically fused operations:
10027        // dst = r @ (time_faaaa * (k @ v) + state),
10028        // state = time_decay * state + (k @ v),
10029        // recursive through each token
10030        for (int64_t t = 0; t < T; t++) {
10031            size_t t_offset = t * t_stride;
10032            size_t state_offset = head_size * C * (t / (T / n_seqs));
10033            float * state_cur = state + state_offset;
10034            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[5]->data + state_offset;
10035
10036            for (int64_t h = h_start; h < h_end; h++) {
10037                size_t h_offset = h * h_stride;
10038                size_t t_h_offset = t_offset + h_offset;
10039                size_t h_2d_offset = h * h_stride_2d;
10040
10041                for (int64_t i = 0; i < head_size; i++) {
10042                    size_t t_h_i_offset = t_h_offset + i;
10043                    size_t h_i_offset = h_offset + i;
10044                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10045
10046                    float k_val = k[t_h_i_offset];
10047                    float r_val = r[t_h_i_offset];
10048                    float time_faaaa_val = time_faaaa[h_i_offset];
10049                    // RWKV v6: different time_decay for each token.
10050                    float time_decay_val = time_decay[t_h_i_offset];
10051
10052                    for (int64_t j = 0; j < head_size; j++) {
10053                        size_t t_h_j_offset = t_h_offset + j;
10054                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10055
10056                        float v_val = v[t_h_j_offset];
10057                        float kv_val = v_val * k_val;
10058                        float prev_state_val = state_prev[h_2d_i_j_offset];
10059                        float temp_val = kv_val * time_faaaa_val + prev_state_val;
10060                        dst_data[t_h_j_offset] += temp_val * r_val;
10061                        state_cur[h_2d_i_j_offset] = prev_state_val * time_decay_val + kv_val;
10062                    }
10063                }
10064            }
10065        }
10066    #endif
10067}
10068
10069
10070void ggml_compute_forward_rwkv_wkv6(
10071        const ggml_compute_params * params,
10072        ggml_tensor * dst) {
10073
10074    const ggml_tensor * src0 = dst->src[0];
10075
10076    switch (src0->type) {
10077        case GGML_TYPE_F32:
10078            {
10079                ggml_compute_forward_rwkv_wkv6_f32(params, dst);
10080            } break;
10081        default:
10082            {
10083                GGML_ABORT("fatal error");
10084            }
10085    }
10086}
10087
10088// ggml_compute_forward_gla
10089
10090static void ggml_compute_forward_gla_f32(
10091        const ggml_compute_params * params,
10092        ggml_tensor * dst) {
10093    const int64_t T = dst->src[1]->ne[2];
10094    const int64_t C = dst->ne[0];
10095    const int64_t HEADS = dst->src[1]->ne[1];
10096    const int64_t n_seqs = dst->src[4]->ne[1];
10097    const int64_t head_size = C / HEADS;
10098    const float scale = ggml_get_op_params_f32(dst, 0);
10099
10100    float * dst_data = (float *) dst->data;
10101    float * state = ((float *) dst->data) + C * T;
10102
10103    const int ith = params->ith;
10104    const int nth = params->nth;
10105
10106    if (ith >= HEADS) {
10107        return;
10108    }
10109
10110    const int h_start = (HEADS * ith) / nth;
10111    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10112                (HEADS * (ith + 1)) / nth : HEADS;
10113
10114    float * k = (float *) dst->src[0]->data;
10115    float * v = (float *) dst->src[1]->data;
10116    float * q = (float *) dst->src[2]->data;
10117    float * g = (float *) dst->src[3]->data;
10118
 10119    size_t t_stride = HEADS * head_size; // same as C
10120
10121    size_t h_stride = C / HEADS;
10122    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10123    size_t h_stride_2d = head_size * head_size;
10124
10125    if (ith == 0) {
10126        memset(dst_data, 0, T * C * sizeof(float));
10127    }
10128    ggml_barrier(params->threadpool);
10129
10130
10131    #if defined(__AVX__) && !defined(__AVX512F__)
10132        #define GGML_F32X GGML_F32x8
10133        #define GGML_F32X_SET1 GGML_F32x8_SET1
10134        #define GGML_F32X_LOAD GGML_F32x8_LOAD
10135        #define GGML_F32X_STORE GGML_F32x8_STORE
10136        #define GGML_F32X_MUL GGML_F32x8_MUL
10137        #define GGML_F32X_FMA GGML_F32x8_FMA
10138        #define GLA_VECTOR_SIZE 8
10139    #elif defined(__AVX512F__)
10140        #define GGML_F32X GGML_F32x16
10141        #define GGML_F32X_SET1 GGML_F32x16_SET1
10142        #define GGML_F32X_LOAD GGML_F32x16_LOAD
10143        #define GGML_F32X_STORE GGML_F32x16_STORE
10144        #define GGML_F32X_MUL GGML_F32x16_MUL
10145        #define GGML_F32X_FMA GGML_F32x16_FMA
10146        #define GLA_VECTOR_SIZE 16
10147    #elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
10148        #define GGML_F32X GGML_F32xt
10149        #define GGML_F32X_SET1 GGML_F32xt_SET1
10150        #define GGML_F32X_LOAD GGML_F32xt_LOAD
10151        #define GGML_F32X_STORE GGML_F32xt_STORE
10152        #define GGML_F32X_MUL GGML_F32xt_MUL
10153        #define GGML_F32X_FMA GGML_F32xt_FMA
10154        #define GLA_VECTOR_SIZE 8
10155    #elif defined(__ARM_NEON) && defined(__aarch64__)
10156        #define GGML_F32X GGML_F32x4
10157        #define GGML_F32X_SET1 GGML_F32x4_SET1
10158        #define GGML_F32X_LOAD GGML_F32x4_LOAD
10159        #define GGML_F32X_STORE GGML_F32x4_STORE
10160        #define GGML_F32X_MUL GGML_F32x4_MUL
10161        #define GGML_F32X_FMA GGML_F32x4_FMA
10162        #define GLA_VECTOR_SIZE 4
10163    #endif
10164
10165    #ifdef GLA_VECTOR_SIZE
10166        int gla_vector_size;
10167        #if defined(__ARM_FEATURE_SVE)
10168            gla_vector_size = svcntw();
10169        #else
10170            gla_vector_size = GLA_VECTOR_SIZE;
10171        #endif
10172        const int64_t vec_count = head_size / gla_vector_size;
10173
10174        for (int64_t t = 0; t < T; t++) {
10175            size_t t_offset = t * t_stride;
10176            size_t state_offset = head_size * C * (t / (T / n_seqs));
10177            float * state_cur = state + state_offset;
10178            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10179
10180            for (int64_t h = h_start; h < h_end; h++) {
10181                size_t h_offset = h * h_stride;
10182                size_t t_h_offset = t_offset + h_offset;
10183                size_t h_2d_offset = h * h_stride_2d;
10184
10185                for (int64_t i = 0; i < head_size; i++) {
10186                    size_t t_h_i_offset = t_h_offset + i;
10187                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10188
10189                    float k_val = k[t_h_i_offset];
10190                    float q_val = q[t_h_i_offset] * scale;
10191                    float g_val = g[t_h_i_offset];
10192
10193                    // Broadcast scalar values to vectors
10194                    GGML_F32X k_vec = GGML_F32X_SET1(k_val);
10195                    GGML_F32X q_vec = GGML_F32X_SET1(q_val);
10196                    GGML_F32X g_vec = GGML_F32X_SET1(g_val);
10197
10198                    for (int64_t j = 0; j < vec_count; j++) {
10199                        size_t base_j = j * gla_vector_size;
10200                        size_t t_h_j_offset = t_h_offset + base_j;
10201                        size_t h_2d_i_j_offset = h_2d_i_offset + base_j;
10202
 10203                        // Load gla_vector_size elements at once
10204                        GGML_F32X v_vec = GGML_F32X_LOAD(&v[t_h_j_offset]);
10205                        GGML_F32X prev_state_vec = GGML_F32X_LOAD(&state_prev[h_2d_i_j_offset]);
10206                        GGML_F32X dst_vec = GGML_F32X_LOAD(&dst_data[t_h_j_offset]);
10207
10208                        // Compute kv = v * k
10209                        GGML_F32X kv_vec = GGML_F32X_MUL(v_vec, k_vec);
10210
10211                        // Compute temp = prev_state * g + kv
10212                        GGML_F32X temp_vec = GGML_F32X_FMA(kv_vec, prev_state_vec, g_vec);
10213
10214                        // Update dst: dst += temp * q
10215                        dst_vec = GGML_F32X_FMA(dst_vec, temp_vec, q_vec);
10216                        GGML_F32X_STORE(&dst_data[t_h_j_offset], dst_vec);
10217
10218                        // Update state
10219                        GGML_F32X_STORE(&state_cur[h_2d_i_j_offset], temp_vec);
10220                    }
10221
 10222                    // Handle remaining elements (normally none: head_size is a multiple of the vector width)
10223                    for (int64_t j = vec_count * gla_vector_size; j < head_size; j++) {
10224                        size_t t_h_j_offset = t_h_offset + j;
10225                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10226                        float v_val = v[t_h_j_offset];
10227                        float kv_val = v_val * k_val;
10228                        float prev_state_val = state_prev[h_2d_i_j_offset];
10229                        float temp_val = kv_val + prev_state_val * g_val;
10230                        dst_data[t_h_j_offset] += temp_val * q_val;
10231                        state_cur[h_2d_i_j_offset] = temp_val;
10232                    }
10233                }
10234            }
10235        }
10236
10237    #else
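              // basically fused operations:
              // state = g * state + (k @ v),
              // dst   = q @ state,
              // recursive through each token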
10238        for (int64_t t = 0; t < T; t++) {
10239            size_t t_offset = t * t_stride;
10240            size_t state_offset = head_size * C * (t / (T / n_seqs));
10241            float * state_cur = state + state_offset;
10242            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[4]->data + state_offset;
10243
10244            for (int64_t h = h_start; h < h_end; h++) {
10245                size_t h_offset = h * h_stride;
10246                size_t t_h_offset = t_offset + h_offset;
10247                size_t h_2d_offset = h * h_stride_2d;
10248
10249                for (int64_t i = 0; i < head_size; i++) {
10250                    size_t t_h_i_offset = t_h_offset + i;
10251                    size_t h_2d_i_offset = h_2d_offset + i * h_stride;
10252
10253                    float k_val = k[t_h_i_offset];
10254                    float q_val = q[t_h_i_offset] * scale;
10255                    float g_val = g[t_h_i_offset];
10256
10257                    for (int64_t j = 0; j < head_size; j++) {
10258                        size_t t_h_j_offset = t_h_offset + j;
10259                        size_t h_2d_i_j_offset = h_2d_i_offset + j;
10260
10261                        float v_val = v[t_h_j_offset];
10262                        float kv_val = v_val * k_val;
10263                        float prev_state_val = state_prev[h_2d_i_j_offset];
10264                        float temp_val = prev_state_val * g_val + kv_val;
10265                        dst_data[t_h_j_offset] += temp_val * q_val;
10266                        state_cur[h_2d_i_j_offset] = temp_val;
10267                    }
10268                }
10269            }
10270        }
10271    #endif
10272}
10273
10274
10275void ggml_compute_forward_gla(
10276        const ggml_compute_params * params,
10277        ggml_tensor * dst) {
10278
10279    const ggml_tensor * src0 = dst->src[0];
10280
10281    switch (src0->type) {
10282        case GGML_TYPE_F32:
10283            {
10284                ggml_compute_forward_gla_f32(params, dst);
10285            } break;
10286        default:
10287            {
10288                GGML_ABORT("fatal error");
10289            }
10290    }
10291}
10292
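     // ggml_compute_forward_solve_tri
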
10293static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10294    const struct ggml_tensor * src0 = dst->src[0];  // A (lower triangular)
10295    const struct ggml_tensor * src1 = dst->src[1];  // B (RHS)
10296
10297    GGML_TENSOR_BINARY_OP_LOCALS;
10298
10299    GGML_ASSERT(src0->type == GGML_TYPE_F32);
10300    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10301    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
10302
10303    GGML_ASSERT(ne00 == ne01); // A must be square
10304    GGML_ASSERT(ne0  == ne10); // solution cols == B cols
10305    GGML_ASSERT(ne1  == ne11); // solution rows == B rows
10306
10307    GGML_ASSERT(ne02 == ne12 && ne12 == ne2);
10308    GGML_ASSERT(ne03 == ne13 && ne13 == ne3);
10309
10310    const int ith = params->ith;
10311    const int nth = params->nth;
10312
10313    const int64_t k = ne10;   // number of RHS columns
10314    const int64_t n = ne11;   // A is n×n
10315    const int64_t nr = ne02 * ne03 * k; // parallelize over columns: each work unit is one (i03, i02, column) combination
10316
10317    // chunks per thread
10318    const int64_t dr = (nr + nth - 1)/nth;
10319
10320    // chunk range for this thread
10321    const int64_t ir0 = dr*ith;
10322    const int64_t ir1 = MIN(ir0 + dr, nr);
10323
10324    const float * A = (const float *) src0->data;  // [n, n, B1, B2]
10325    const float * B = (const float *) src1->data;  // [n, k, B1, B2]
10326          float * X = (      float *) dst->data;   // [n, k, B1, B2]
10327
10328    for (int64_t ir = ir0; ir < ir1; ++ir) {
10329        const int64_t i03 = ir/(ne02*k);
10330        const int64_t i02 = (ir - i03*ne02*k)/k;
10331        const int64_t i01 = (ir - i03*ne02*k - i02*k);
10332
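             // nbX2/nbX3 are byte strides; convert them to float element offsets
             // (rows within each matrix are assumed to be contiguous F32)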
10333        const float * A_batch = A + i02 * nb02 / sizeof(float) + i03 * nb03 / sizeof(float);
10334        const float * B_batch = B + i02 * nb12 / sizeof(float) + i03 * nb13 / sizeof(float);
10335
10336        float * X_batch = X + i02 * nb2 / sizeof(float) + i03 * nb3 / sizeof(float);
10337
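             // forward substitution on the lower-triangular system A * X = B:
             //   X[i, col] = (B[i, col] - Σ_{t<i} A[i, t] * X[t, col]) / A[i, i]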
10338        for (int64_t i00 = 0; i00 < n; ++i00) {
10339            float sum = 0.0f;
10340            for (int64_t t = 0; t < i00; ++t) {
10341                sum += A_batch[i00 * n + t] * X_batch[t * k + i01];
10342            }
10343
10344            const float diag = A_batch[i00 * n + i00];
10345            assert(diag != 0.0f && "Zero diagonal in triangular matrix");
10346
10347            X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
10348        }
10349    }
10350}
10351
10352void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
10353    const ggml_tensor * src0 = dst->src[0];
10354    const ggml_tensor * src1 = dst->src[1];
10355
10356    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
10357        ggml_compute_forward_solve_tri_f32(params, dst);
10358    } else {
10359        GGML_ABORT("fatal error");
10360    }
10361}
10362
10363// ggml_compute_forward_rwkv_wkv7
10364
10365static void ggml_compute_forward_rwkv_wkv7_f32(
10366        const ggml_compute_params * params,
10367        ggml_tensor * dst) {
10368    const int64_t T = dst->src[1]->ne[2];
10369    const int64_t C = dst->ne[0];
10370    const int64_t HEADS = dst->src[1]->ne[1];
10371    const int64_t n_seqs = dst->src[6]->ne[1];
10372    const int64_t head_size = C / HEADS;
10373
10374    float * dst_data = (float *) dst->data;
10375    float * state = ((float *) dst->data) + C * T;
10376
10377    const int ith = params->ith;
10378    const int nth = params->nth;
10379
10380    if (ith >= HEADS) {
10381        return;
10382    }
10383
10384    const int h_start = (HEADS * ith) / nth;
10385    const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ?
10386                (HEADS * (ith + 1)) / nth : HEADS;
10387
10388    float * r = (float *) dst->src[0]->data;
10389    float * w = (float *) dst->src[1]->data;
10390    float * k = (float *) dst->src[2]->data;
10391    float * v = (float *) dst->src[3]->data;
10392    float * a = (float *) dst->src[4]->data;
10393    float * b = (float *) dst->src[5]->data;
10394
10395    int64_t t_stride = HEADS * head_size; // same as C
10396
10397    int64_t h_stride = C / HEADS;
10398    GGML_ASSERT(C % HEADS == 0); // C must be divisible by HEADS
10399    int64_t h_stride_2d = head_size * head_size;
10400
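         // per (head h, row i) recurrence, applied at each time step t:
         //   sa          = Σ_j a[j] * state_prev[i][j]
         //   state[i][j] = state_prev[i][j]*w[j] + v[i]*k[j] + sa*b[j]
         //   dst[i]      = Σ_j state[i][j] * r[j]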
10401    #if defined(GGML_SIMD)
10402        #if defined(__ARM_FEATURE_SVE) || defined(__riscv_v_intrinsic)
10403            // route to the scalar implementation    // TODO: write SVE and RVV code
10404            for (int64_t t = 0; t < T; t++) {
10405                int64_t t_offset = t * t_stride;
10406                int64_t state_offset = head_size * C * (t / (T / n_seqs));
10407                float * state_cur = state + state_offset;
10408                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10409
10410                for (int64_t h = h_start; h < h_end; h++) {
10411                    int64_t h_offset = h * h_stride;
10412                    int64_t t_h_offset = t_offset + h_offset;
10413                    int64_t h_2d_offset = h * h_stride_2d;
10414
10415                    for (int64_t i = 0; i < head_size; i++) {
10416                        int64_t t_h_i_offset = t_h_offset + i;
10417                        int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10418
10419                        float v_val = v[t_h_i_offset];
10420
10421                        float sa = 0, result = 0;
10422                        for (int64_t j = 0; j < head_size; j++) {
10423                            sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10424                        }
10425
10426                        for (int64_t j = 0; j < head_size; j++) {
10427                            int64_t t_h_j_offset = t_h_offset + j;
10428                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10429
10430                            float r_val = r[t_h_j_offset];
10431                            float w_val = w[t_h_j_offset];
10432                            float k_val = k[t_h_j_offset];
10433                            float b_val = b[t_h_j_offset];
10434                            float kv_val = v_val * k_val;
10435                            float prev_state_val = state_prev[h_2d_i_j_offset];
10436                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10437                            result += state_cur[h_2d_i_j_offset] * r_val;
10438                        }
10439                        dst_data[t_h_i_offset] = result;
10440                    }
10441                }
10442            }
10443        #else
10444            for (int64_t t = 0; t < T; t++) {
10445                int64_t t_offset = t * t_stride;
10446                int64_t state_offset = head_size * C * (t / (T / n_seqs));
10447                float * state_cur = state + state_offset;
10448                float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10449
10450                for (int64_t h = h_start; h < h_end; h++) {
10451                    int64_t h_offset = h * h_stride;
10452                    int64_t t_h_offset = t_offset + h_offset;
10453                    int64_t h_2d_offset = h * h_stride_2d;
10454
10455                    for (int64_t ii = 0; ii < head_size; ii++) {
10456                        int64_t t_h_i_offset = t_h_offset + ii;
10457                        int64_t h_2d_i_offset = h_2d_offset + ii * h_stride;
10458
10459                        GGML_F32_VEC v_vec = GGML_F32_VEC_SET1(v[t_h_i_offset]);
10460
10461                        float sa = 0;
10462                        {
10463                            GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10464                            GGML_F32_VEC ax[GGML_F32_ARR];
10465                            GGML_F32_VEC ay[GGML_F32_ARR];
10466                            for (int64_t j = 0; j < head_size; j += GGML_F32_STEP) {
10467                                for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10468                                    ax[kk] = GGML_F32_VEC_LOAD(&a[t_h_offset + j + kk * GGML_F32_EPR]);
10469                                    ay[kk] = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_offset + j + kk * GGML_F32_EPR]);
10470                                    sum[kk] = GGML_F32_VEC_FMA(sum[kk], ax[kk], ay[kk]);
10471                                }
10472                            }
10473                            GGML_F32_VEC_REDUCE(sa, sum);
10474                        }
10475
10476                        GGML_F32_VEC sa_vec = GGML_F32_VEC_SET1(sa);
10477
10478                        int64_t j = 0;
10479                        GGML_F32_VEC result_vec[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
10480                        for (; j < head_size; j += GGML_F32_STEP) {
10481                            for (int64_t kk = 0; kk < GGML_F32_ARR; kk++) {
10482                                int64_t t_h_j_offset = t_h_offset + j + kk * GGML_F32_EPR;
10483                                int64_t h_2d_i_j_offset = h_2d_i_offset + j + kk * GGML_F32_EPR;
10484
10485                                GGML_F32_VEC r_vec = GGML_F32_VEC_LOAD(&r[t_h_j_offset]);
10486                                GGML_F32_VEC w_vec = GGML_F32_VEC_LOAD(&w[t_h_j_offset]);
10487                                GGML_F32_VEC k_vec = GGML_F32_VEC_LOAD(&k[t_h_j_offset]);
10488                                GGML_F32_VEC b_vec = GGML_F32_VEC_LOAD(&b[t_h_j_offset]);
10489
10490                                k_vec = GGML_F32_VEC_MUL(v_vec, k_vec);
10491
10492                                GGML_F32_VEC state_vec = GGML_F32_VEC_LOAD(&state_prev[h_2d_i_j_offset]);
10493                                // kv + s * decay + sa * b
10494                                state_vec = GGML_F32_VEC_FMA(k_vec, state_vec, w_vec);
10495                                state_vec = GGML_F32_VEC_FMA(state_vec, sa_vec, b_vec);
10496                                GGML_F32_VEC_STORE(&state_cur[h_2d_i_j_offset], state_vec);
10497
10498                                result_vec[kk] = GGML_F32_VEC_FMA(result_vec[kk], state_vec, r_vec);
10499                            }
10500                        }
10501                        GGML_F32_VEC_REDUCE(dst_data[t_h_i_offset], result_vec);
10502
10503                        // handle any leftover elements (none are expected, since head_size is a multiple of GGML_F32_STEP)
10504                        for (; j < head_size; j++) {
10505                            int64_t t_h_j_offset = t_h_offset + j;
10506                            int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10507
10508                            float r_val = r[t_h_j_offset];
10509                            float w_val = w[t_h_j_offset];
10510                            float k_val = k[t_h_j_offset];
10511                            float b_val = b[t_h_j_offset];
10512                            float kv_val = v[t_h_i_offset] * k_val;
10513
10514                            float prev_state_val = state_prev[h_2d_i_j_offset];
10515                            state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10516                            dst_data[t_h_i_offset] += state_cur[h_2d_i_j_offset] * r_val;
10517                        }
10518                    }
10519                }
10520            }
10521        #endif
10522    #else
10523        for (int64_t t = 0; t < T; t++) {
10524            int64_t t_offset = t * t_stride;
10525            int64_t state_offset = head_size * C * (t / (T / n_seqs));
10526            float * state_cur = state + state_offset;
10527            float * state_prev = t % (T / n_seqs) ? state_cur : (float*)dst->src[6]->data + state_offset;
10528
10529            for (int64_t h = h_start; h < h_end; h++) {
10530                int64_t h_offset = h * h_stride;
10531                int64_t t_h_offset = t_offset + h_offset;
10532                int64_t h_2d_offset = h * h_stride_2d;
10533
10534                for (int64_t i = 0; i < head_size; i++) {
10535                    int64_t t_h_i_offset = t_h_offset + i;
10536                    int64_t h_2d_i_offset = h_2d_offset + i * h_stride;
10537
10538                    float v_val = v[t_h_i_offset];
10539
10540                    float sa = 0, result = 0;
10541                    for (int64_t j = 0; j < head_size; j++) {
10542                        sa += a[t_h_offset + j] * state_prev[h_2d_i_offset + j];
10543                    }
10544
10545                    for (int64_t j = 0; j < head_size; j++) {
10546                        int64_t t_h_j_offset = t_h_offset + j;
10547                        int64_t h_2d_i_j_offset = h_2d_i_offset + j;
10548
10549                        float r_val = r[t_h_j_offset];
10550                        float w_val = w[t_h_j_offset];
10551                        float k_val = k[t_h_j_offset];
10552                        float b_val = b[t_h_j_offset];
10553                        float kv_val = v_val * k_val;
10554                        float prev_state_val = state_prev[h_2d_i_j_offset];
10555                        state_cur[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
10556                        result += state_cur[h_2d_i_j_offset] * r_val;
10557                    }
10558                    dst_data[t_h_i_offset] = result;
10559                }
10560            }
10561        }
10562    #endif
10563}
10564
10565
10566void ggml_compute_forward_rwkv_wkv7(
10567        const ggml_compute_params * params,
10568        ggml_tensor * dst) {
10569
10570    const ggml_tensor * src0 = dst->src[0];
10571
10572    switch (src0->type) {
10573        case GGML_TYPE_F32:
10574            {
10575                ggml_compute_forward_rwkv_wkv7_f32(params, dst);
10576            } break;
10577        default:
10578            {
10579                GGML_ABORT("fatal error");
10580            }
10581    }
10582}
10583
10584// ggml_compute_forward_map_custom1
10585
10586void ggml_compute_forward_map_custom1(
10587        const ggml_compute_params * params,
10588              ggml_tensor * dst) {
10589
10590    const ggml_tensor * a = dst->src[0];
10591
10592    struct ggml_map_custom1_op_params p;
10593    memcpy(&p, dst->op_params, sizeof(p));
10594
10595    p.fun(dst, a, params->ith, params->nth, p.userdata);
10596}
10597
10598// ggml_compute_forward_map_custom2
10599
10600void ggml_compute_forward_map_custom2(
10601        const ggml_compute_params * params,
10602              ggml_tensor * dst) {
10603
10604    const ggml_tensor * a = dst->src[0];
10605    const ggml_tensor * b = dst->src[1];
10606
10607    struct ggml_map_custom2_op_params p;
10608    memcpy(&p, dst->op_params, sizeof(p));
10609
10610    p.fun(dst, a, b, params->ith, params->nth, p.userdata);
10611}
10612
10613// ggml_compute_forward_map_custom3
10614
10615void ggml_compute_forward_map_custom3(
10616        const ggml_compute_params * params,
10617              ggml_tensor * dst) {
10618
10619    const ggml_tensor * a = dst->src[0];
10620    const ggml_tensor * b = dst->src[1];
10621    const ggml_tensor * c = dst->src[2];
10622
10623    struct ggml_map_custom3_op_params p;
10624    memcpy(&p, dst->op_params, sizeof(p));
10625
10626    p.fun(dst, a, b, c, params->ith, params->nth, p.userdata);
10627}
10628
10629// ggml_compute_forward_custom
10630
10631void ggml_compute_forward_custom(
10632    const struct ggml_compute_params * params,
10633          struct ggml_tensor * dst) {
10634
10635    struct ggml_custom_op_params p;
10636    memcpy(&p, dst->op_params, sizeof(p));
10637
10638    p.fun(dst, params->ith, params->nth, p.userdata);
10639}
10640
10641// ggml_compute_forward_cross_entropy_loss
10642
10643static void ggml_compute_forward_cross_entropy_loss_f32(
10644        const ggml_compute_params * params,
10645        ggml_tensor * dst) {
10646
10647    const ggml_tensor * src0 = dst->src[0];
10648    const ggml_tensor * src1 = dst->src[1];
10649
10650    GGML_ASSERT(src0->type == GGML_TYPE_F32);
10651    GGML_ASSERT(src1->type == GGML_TYPE_F32);
10652    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
10653    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
10654    GGML_ASSERT(ggml_are_same_shape(src0, src1));
10655    GGML_ASSERT(ggml_is_scalar(dst));
10656    GGML_ASSERT(dst->type == GGML_TYPE_F32);
10657
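     // loss = -1/nr * Σ_rows Σ_cols src1 * log_softmax(src0)
     // each thread accumulates its rows into sums[ith]; thread 0 reduces and applies the -1/nr factor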
10658    // TODO: handle transposed/permuted matrices
10659    const int64_t nc = src0->ne[0];
10660    const int64_t nr = ggml_nrows(src0);
10661
10662    const int ith = params->ith;
10663    const int nth = params->nth;
10664
10665    float * sums =  (float *) params->wdata;
10666    float * st   = ((float *) params->wdata) + nth + ith*nc;
10667    float sum_thread = 0.0f;
10668
10669    GGML_ASSERT(params->wsize >= sizeof(float) * (nth + nth * nc));
10670
10671    // rows per thread
10672    const int64_t dr = (nr + nth - 1)/nth;
10673
10674    // row range for this thread
10675    const int64_t ir0 = dr*ith;
10676    const int64_t ir1 = MIN(ir0 + dr, nr);
10677
10678    for (int64_t i1 = ir0; i1 < ir1; ++i1) {
10679        const float * s0 = (const float *)((const char *) src0->data + i1*src0->nb[1]);
10680        const float * s1 = (const float *)((const char *) src1->data + i1*src1->nb[1]);
10681
10682#ifndef NDEBUG
10683        for (int64_t i = 0; i < nc; ++i) {
10684            //printf("p[%d] = %f\n", i, p[i]);
10685            assert(!isnan(s0[i]));
10686            assert(!isnan(s1[i]));
10687        }
10688#endif
10689
10690        float max = -INFINITY;
10691        ggml_vec_max_f32(nc, &max, s0);
10692        const ggml_float sum_softmax = ggml_vec_log_soft_max_f32(nc, st, s0, max);
10693        assert(sum_softmax >= 0.0);
10694
10695        ggml_vec_add1_f32(nc, st, st, -sum_softmax);
10696        ggml_vec_mul_f32(nc, st, st, s1);
10697
10698        float sum_st = 0.0f;
10699        ggml_vec_sum_f32(nc, &sum_st, st);
10700        sum_thread += sum_st;
10701
10702#ifndef NDEBUG
10703        for (int64_t i = 0; i < nc; ++i) {
10704            assert(!isnan(st[i]));
10705            assert(!isinf(st[i]));
10706        }
10707#endif
10708    }
10709    sums[ith] = sum_thread;
10710    ggml_barrier(params->threadpool);
10711
10712    if (ith == 0) {
10713        float * dp = (float *) dst->data;
10714        ggml_vec_sum_f32(nth, dp, sums);
10715        dp[0] *= -1.0f / (float) nr;
10716    }
10717}
10718
10719void ggml_compute_forward_cross_entropy_loss(
10720        const ggml_compute_params * params,
10721        ggml_tensor * dst) {
10722
10723    const ggml_tensor * src0 = dst->src[0];
10724
10725    switch (src0->type) {
10726        case GGML_TYPE_F32:
10727            {
10728                ggml_compute_forward_cross_entropy_loss_f32(params, dst);
10729            } break;
10730        default:
10731            {
10732                GGML_ABORT("fatal error");
10733            }
10734    }
10735}
10736
10737// ggml_compute_forward_cross_entropy_loss_back
10738
10739static void ggml_compute_forward_cross_entropy_loss_back_f32(
10740        const ggml_compute_params * params,
10741        ggml_tensor * dst) {
10742
10743    const ggml_tensor * grad  = dst->src[0]; // gradient of forward pass output
10744    const ggml_tensor * src0f = dst->src[1]; // src0 of forward pass
10745    const ggml_tensor * src1f = dst->src[2]; // src1 of forward pass
10746
10747    GGML_ASSERT(ggml_is_contiguous(dst));
10748    GGML_ASSERT(ggml_is_contiguous(src0f));
10749    GGML_ASSERT(ggml_is_contiguous(src1f));
10750    GGML_ASSERT(ggml_is_contiguous(grad));
10751    GGML_ASSERT(ggml_are_same_shape(src0f, src1f) && ggml_are_same_shape(src0f, dst));
10752
10753    const int64_t ith = params->ith;
10754    const int64_t nth = params->nth;
10755
10756    // TODO: handle transposed/permuted matrices
10757    const int64_t nc = src0f->ne[0];
10758    const int64_t nr = ggml_nrows(src0f);
10759
10760    // rows per thread
10761    const int64_t dr = (nr + nth - 1)/nth;
10762
10763    // row range for this thread
10764    const int64_t ir0 = dr*ith;
10765    const int64_t ir1 = MIN(ir0 + dr, nr);
10766
10767    const float d_by_nr = ((const float *) grad->data)[0] / (float) nr;
10768
10769    for (int64_t i1 = ir0; i1 < ir1; i1++) {
10770        float       * ds0 = (float       *)((char       *) dst->data   + i1*dst->nb[1]);
10771        const float * s0  = (const float *)((const char *) src0f->data + i1*src0f->nb[1]);
10772        const float * s1  = (const float *)((const char *) src1f->data + i1*src1f->nb[1]);
10773
10774#ifndef NDEBUG
10775        for (int64_t i = 0; i < nc; ++i) {
10776            //printf("p[%d] = %f\n", i, p[i]);
10777            assert(!isnan(s0[i]));
10778            assert(!isnan(s1[i]));
10779        }
10780#endif
10781
10782        // soft_max
10783        float max = -INFINITY;
10784        ggml_vec_max_f32(nc, &max, s0);
10785        const ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
10786        assert(sum > 0.0);
10787        ggml_vec_scale_f32(nc, ds0, 1.0/sum);
10788
10789        // grad(src0f) = (softmax(src0f) - src1f) * grad(cross_entropy_loss(src0f, src1f)) / nr
10790        ggml_vec_sub_f32(nc, ds0, ds0, s1);
10791        ggml_vec_scale_f32(nc, ds0, d_by_nr);
10792
10793#ifndef NDEBUG
10794        for (int64_t i = 0; i < nc; ++i) {
10795            assert(!isnan(ds0[i]));
10796            assert(!isinf(ds0[i]));
10797        }
10798#endif
10799    }
10800}
10801
10802void ggml_compute_forward_cross_entropy_loss_back(
10803        const ggml_compute_params * params,
10804        ggml_tensor * dst) {
10805
10806    const ggml_tensor * src0 = dst->src[0];
10807
10808    switch (src0->type) {
10809        case GGML_TYPE_F32:
10810            {
10811                ggml_compute_forward_cross_entropy_loss_back_f32(params, dst);
10812            } break;
10813        default:
10814            {
10815                GGML_ABORT("fatal error");
10816            }
10817    }
10818}
10819
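     // ggml_compute_forward_opt_step_adamw
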
10820static void ggml_compute_forward_opt_step_adamw_f32(
10821        const ggml_compute_params * params,
10822        ggml_tensor * dst) {
10823
10824    const ggml_tensor * src0         = dst->src[0];
10825    const ggml_tensor * src0_grad    = dst->src[1];
10826    const ggml_tensor * src0_grad_m  = dst->src[2];
10827    const ggml_tensor * src0_grad_v  = dst->src[3];
10828    const ggml_tensor * adamw_params = dst->src[4];
10829
10830    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10831    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
10832    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
10833    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
10834
10835    const int ith = params->ith;
10836    const int nth = params->nth;
10837
10838    const int nr  = ggml_nrows(src0);
10839
10840    GGML_TENSOR_UNARY_OP_LOCALS
10841    GGML_ASSERT(nb00 == sizeof(float));
10842
10843    // rows per thread
10844    const int dr = (nr + nth - 1)/nth;
10845
10846    // row range for this thread
10847    const int ir0 = dr*ith;
10848    const int ir1 = MIN(ir0 + dr, nr);
10849
10850    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
10851
10852    const float alpha  = adamw_params_ptr[0];
10853    const float beta1  = adamw_params_ptr[1];
10854    const float beta2  = adamw_params_ptr[2];
10855    const float eps    = adamw_params_ptr[3];
10856    const float wd     = adamw_params_ptr[4];
10857    const float beta1h = adamw_params_ptr[5];
10858    const float beta2h = adamw_params_ptr[6];
10859    const float keep   = 1.f - alpha * wd;
10860    for (int ir = ir0; ir < ir1; ++ir) {
10861        const int64_t i03 = ir/(ne02*ne01);
10862        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
10863        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
10864
10865        const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
10866
10867        float       * w = (float       *) ((char       *) src0->data        + offset); // weight
10868        const float * g = (const float *) ((const char *) src0_grad->data   + offset); // grad
10869        float       * m = (float       *) ((char       *) src0_grad_m->data + offset);
10870        float       * v = (float       *) ((char       *) src0_grad_v->data + offset);
10871
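             // mh/vh are the bias-corrected moments; beta1h and beta2h are precomputed
             // correction factors (presumably 1/(1 - beta^t)) supplied via adamw_params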
10872        for (int i00 = 0; i00 < ne00; ++i00) {
10873            m[i00] = m[i00]*beta1 +        g[i00]*(1.0f - beta1);
10874            v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
10875
10876            const float mh =       m[i00]*beta1h;
10877            const float vh = sqrtf(v[i00]*beta2h) + eps;
10878
10879            // The weight decay is applied independently of the Adam momenta m and v.
10880            // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
10881            // See: https://arxiv.org/pdf/1711.05101v3.pdf
10882            w[i00] = w[i00] * keep - alpha * mh / vh;
10883        }
10884    }
10885}
10886
10887void ggml_compute_forward_opt_step_adamw(
10888        const ggml_compute_params * params,
10889        ggml_tensor * dst) {
10890
10891    const ggml_tensor * src0 = dst->src[0];
10892
10893    switch (src0->type) {
10894        case GGML_TYPE_F32:
10895            {
10896                ggml_compute_forward_opt_step_adamw_f32(params, dst);
10897            } break;
10898        default:
10899            {
10900                GGML_ABORT("fatal error");
10901            }
10902    }
10903}
10904
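     // ggml_compute_forward_opt_step_sgd
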
10905static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
10906    const ggml_tensor * src0       = dst->src[0];
10907    const ggml_tensor * src0_grad  = dst->src[1];
10908    const ggml_tensor * sgd_params = dst->src[2];
10909
10910    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
10911    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
10912
10913    const int ith = params->ith;
10914    const int nth = params->nth;
10915
10916    const int nr = ggml_nrows(src0);
10917
10918    GGML_TENSOR_UNARY_OP_LOCALS
10919    GGML_ASSERT(nb00 == sizeof(float));
10920
10921    // rows per thread
10922    const int dr = (nr + nth - 1) / nth;
10923
10924    // row range for this thread
10925    const int ir0 = dr * ith;
10926    const int ir1 = MIN(ir0 + dr, nr);
10927
10928    // reuse the AdamW parameter layout, reading only the subset SGD needs (alpha, wd); a dedicated struct could replace this
10929    const float * sgd_params_ptr   = ggml_get_data_f32(sgd_params);
10930    const float   alpha            = sgd_params_ptr[0];
10931    const float   keep             = 1.f - alpha * sgd_params_ptr[1];
10932
10933    for (int ir = ir0; ir < ir1; ++ir) {
10934        const int64_t i03 = ir / (ne02 * ne01);
10935        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
10936        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
10937
10938        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
10939
10940        float *       w = (float *) ((char *) src0->data + offset);                   // weight
10941        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
10942
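             // SGD with decoupled weight decay: w = w*(1 - alpha*wd) - alpha*g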
10943        for (int i00 = 0; i00 < ne00; ++i00) {
10944            w[i00] = w[i00] * keep - alpha * g[i00];
10945        }
10946    }
10947}
10948
10949void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
10950    const ggml_tensor * src0 = dst->src[0];
10951
10952    switch (src0->type) {
10953        case GGML_TYPE_F32:
10954            {
10955                ggml_compute_forward_opt_step_sgd_f32(params, dst);
10956            }
10957            break;
10958        default:
10959            {
10960                GGML_ABORT("fatal error - sgd is F32 only");
10961            }
10962    }
10963}