1#include "ggml-cuda.h"
   2#include "ggml-impl.h"
   3#include "ggml-backend-impl.h"
   4
   5#include "ggml-cuda/common.cuh"
   6#include "ggml-cuda/acc.cuh"
   7#include "ggml-cuda/add-id.cuh"
   8#include "ggml-cuda/arange.cuh"
   9#include "ggml-cuda/argmax.cuh"
  10#include "ggml-cuda/argsort.cuh"
  11#include "ggml-cuda/binbcast.cuh"
  12#include "ggml-cuda/clamp.cuh"
  13#include "ggml-cuda/concat.cuh"
  14#include "ggml-cuda/conv-transpose-1d.cuh"
  15#include "ggml-cuda/conv2d.cuh"
  16#include "ggml-cuda/conv2d-dw.cuh"
  17#include "ggml-cuda/conv2d-transpose.cuh"
  18#include "ggml-cuda/convert.cuh"
  19#include "ggml-cuda/count-equal.cuh"
  20#include "ggml-cuda/cpy.cuh"
  21#include "ggml-cuda/cross-entropy-loss.cuh"
  22#include "ggml-cuda/cumsum.cuh"
  23#include "ggml-cuda/diagmask.cuh"
  24#include "ggml-cuda/diag.cuh"
  25#include "ggml-cuda/fattn.cuh"
  26#include "ggml-cuda/getrows.cuh"
  27#include "ggml-cuda/im2col.cuh"
  28#include "ggml-cuda/mmf.cuh"
  29#include "ggml-cuda/mmq.cuh"
  30#include "ggml-cuda/mmvf.cuh"
  31#include "ggml-cuda/mmvq.cuh"
  32#include "ggml-cuda/norm.cuh"
  33#include "ggml-cuda/opt-step-adamw.cuh"
  34#include "ggml-cuda/opt-step-sgd.cuh"
  35#include "ggml-cuda/out-prod.cuh"
  36#include "ggml-cuda/pad.cuh"
  37#include "ggml-cuda/pool2d.cuh"
  38#include "ggml-cuda/quantize.cuh"
  39#include "ggml-cuda/rope.cuh"
  40#include "ggml-cuda/roll.cuh"
  41#include "ggml-cuda/scale.cuh"
  42#include "ggml-cuda/softcap.cuh"
  43#include "ggml-cuda/softmax.cuh"
  44#include "ggml-cuda/ssm-conv.cuh"
  45#include "ggml-cuda/ssm-scan.cuh"
  46#include "ggml-cuda/sum.cuh"
  47#include "ggml-cuda/sumrows.cuh"
  48#include "ggml-cuda/top-k.cuh"
  49#include "ggml-cuda/mean.cuh"
  50#include "ggml-cuda/tsembd.cuh"
  51#include "ggml-cuda/topk-moe.cuh"
  52#include "ggml-cuda/unary.cuh"
  53#include "ggml-cuda/upscale.cuh"
  54#include "ggml-cuda/wkv.cuh"
  55#include "ggml-cuda/gla.cuh"
  56#include "ggml-cuda/set.cuh"
  57#include "ggml-cuda/set-rows.cuh"
  58#include "ggml-cuda/pad_reflect_1d.cuh"
  59#include "ggml-cuda/solve_tri.cuh"
  60#include "ggml-cuda/tri.cuh"
  61#include "ggml-cuda/cumsum.cuh"
  62#include "ggml-cuda/fill.cuh"
  63#include "ggml.h"
  64
  65#include <algorithm>
  66#include <array>
  67#include <atomic>
  68#include <charconv>
  69#include <cinttypes>
  70#include <condition_variable>
  71#include <cstddef>
  72#include <cstdint>
  73#include <cfloat>
  74#include <initializer_list>
  75#include <limits>
  76#include <map>
  77#include <memory>
  78#include <mutex>
  79#include <cstdarg>
  80#include <cstdio>
  81#include <cstdlib>
  82#include <string>
  83#include <vector>
  84#include <unordered_set>
  85
  // CUDA's `half` and ggml's `ggml_fp16_t` must occupy the same number of bytes; fp16 data is presumably
// passed between the two representations without conversion.
static_assert(sizeof(ggml_fp16_t) == sizeof(half), "wrong fp16 size");
  87
  // Report a failed GPU runtime call and terminate.
//   stmt           - source text of the failing call
//   func/file/line - location of the failing call
//   msg            - error description (presumably obtained from the runtime's error-string function)
// Never returns: GGML_ABORT is used so that a stack trace is produced.
[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
    // cudaGetDevice may itself fail while the runtime is in an error state; -1 is reported in that case
    int current_device = -1;
    static_cast<void>(cudaGetDevice(&current_device));

    GGML_LOG_ERROR(GGML_CUDA_NAME " error: %s\n", msg);
    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", current_device, func, file, line);
    GGML_LOG_ERROR("  %s\n", stmt);
    GGML_ABORT(GGML_CUDA_NAME " error");
}
  99
 // Make `device` the current device of the calling thread, skipping cudaSetDevice when it already is.
// Checking first is faster on Windows, probably because the Windows CUDA libraries do not make this
// check themselves before invoking the driver.
void ggml_cuda_set_device(int device) {
    int active_device = -1;
    CUDA_CHECK(cudaGetDevice(&active_device));

    if (active_device != device) {
        CUDA_CHECK(cudaSetDevice(device));
    }
}
 112
 // Return the index of the calling thread's current device (aborts via CUDA_CHECK on error).
int ggml_cuda_get_device() {
    int current_device = 0;
    CUDA_CHECK(cudaGetDevice(&current_device));
    return current_device;
}
 118
 // Allocate `size` bytes on `device` and store the pointer in *ptr (the device is made current first).
// If the GGML_CUDA_ENABLE_UNIFIED_MEMORY environment variable is set, managed memory is allocated instead
// of plain device memory. On HIP, managed memory is advised as coarse-grained, and when managed memory is not
// supported (e.g. on Windows) the allocation falls back to a plain device allocation.
// Returns the status of the allocation that was ultimately used; allocation failures are returned, not aborted on.
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
    ggml_cuda_set_device(device);
    cudaError_t err;
    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
        err = cudaMallocManaged(ptr, size);
#if defined(GGML_USE_HIP)
        if (err == hipSuccess) {
            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
        }

        // fall back to cudaMalloc if not supported (e.g. on Windows)
        if (err == hipErrorNotSupported) {
            // The failed managed allocation may be recorded as the thread's last runtime error. Clear it (as
            // ggml_backend_cuda_buffer_type_alloc_buffer does after a failed allocation) so that a later
            // cudaGetLastError() check does not report this already-handled failure.
            (void)cudaGetLastError();

            // atomic so that concurrent allocations from several threads do not race on the flag
            static std::atomic<bool> warned_unsupported{false};
            if (!warned_unsupported.exchange(true)) {
                GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
            }

            err = cudaMalloc(ptr, size);
        }
#endif // defined(GGML_USE_HIP)
    } else {
        err = cudaMalloc(ptr, size);
    }
    return err;
}
 145
 #if defined(GGML_USE_HIP)
// Convert an AMD gcnArchName (e.g. "gfx1100", "gfx90a:sramecc+:xnack-", "gfx10-1-generic") into ggml's AMD
// compute-capability encoding: GGML_CUDA_CC_OFFSET_AMD + major * 0x100 + minor. The digits of the name are
// read as hexadecimal, e.g. "gfx1100" -> major 0x11, minor 0x00; "gfx90a" -> major 0x9, minor 0x0a.
// If no version can be parsed, the major version stays 0.
static int ggml_cuda_parse_id(char devName[]) {
    // A list of possible Target IDs can be found under the rocclr/clr repo in device.cpp
    // these values are not stable so this is susceptible to breakage
    // https://github.com/ROCm/clr/blob/amd-staging/rocclr/device/device.cpp
    int archMajor = 0x0;
    int archMinor = 0x0;
    int archNum = GGML_CUDA_CC_OFFSET_AMD;

    // Zero-initialized working buffer. It replaces a variable-length array, which is not standard C++.
    // Names of 3 characters or fewer now leave it empty instead of uninitialized, so the string
    // functions below never read indeterminate memory.
    const size_t devLen = strlen(devName);
    std::vector<char> archBuf(devLen + 1, '\0');
    char * archName = archBuf.data();

    // strip leading 'gfx' while copying into our buffer
    if (devLen > 3) {
        strcpy(archName, &devName[3]);
    }

    // trim trailing :xnack- or :sramecc- statuses
    int archLen = (int) strcspn(archName, ":");
    archName[archLen] = '\0';

    // tease out the version information
    if (archLen > 8) {
        // versions labeled generic use '-' as delimiter
        // strip the trailing "-generic" (8 characters) then iterate through what remains
        if ((strstr(archName, "-generic"))) {
            archName[archLen - 8] = '\0';
            char * pch;
            if ((pch = strtok(archName, "-"))) {
                archMajor = (int)strtoul(pch, 0, 16);
                if ((pch = strtok(NULL, "-"))) {
                    archMinor = 0x10 * (int)strtoul(pch, 0, 16);
                }
            }
        }
    } else if (archLen >= 3) {
        // last two digits should be the minor * 0x10 + stepping
        archMinor = (int)strtoul(&archName[archLen - 2], 0, 16);
        archName[archLen - 2] = '\0';

        // only the major version remains
        archMajor = (int)strtoul(archName, 0, 16);
    }
    archNum += archMajor * 0x100;
    archNum += archMinor;
    return archNum;
}
#endif // defined(GGML_USE_HIP)
 194
 // Enumerate the devices and collect the per-device properties used by the backend (VMM support and granularity,
// compute capability, SM count, shared-memory limits, warp size, cooperative-launch support).
// default_tensor_split[id] is set to the fraction of total VRAM held by devices 0..id-1, i.e. the cumulative
// start fraction of device id. If the device count cannot be queried, the error is logged and `info` is
// returned without any per-device data filled in.
static ggml_cuda_device_info ggml_cuda_init() {
    ggml_cuda_device_info info = {};

    cudaError_t err = cudaGetDeviceCount(&info.device_count);
    if (err != cudaSuccess) {
        GGML_LOG_ERROR("%s: failed to initialize " GGML_CUDA_NAME ": %s\n", __func__, cudaGetErrorString(err));
        return info;
    }

    GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);

    int64_t total_vram = 0;
    GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);

    std::vector<std::pair<int, std::string>> turing_devices_without_mma;
    for (int id = 0; id < info.device_count; ++id) {
        int device_vmm = 0;

#if defined(GGML_USE_VMM)
        CUdevice device;
        CU_CHECK(cuDeviceGet(&device, id));
        CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));

        if (device_vmm) {
            CUmemAllocationProp alloc_prop = {};
            alloc_prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
            alloc_prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            alloc_prop.location.id = id;
            CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
        }
#endif // defined(GGML_USE_VMM)
        info.devices[id].vmm = !!device_vmm;

        cudaDeviceProp prop;
        CUDA_CHECK(cudaGetDeviceProperties(&prop, id));

        info.default_tensor_split[id] = total_vram;
        total_vram += prop.totalGlobalMem;
        info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
        info.devices[id].nsm        = prop.multiProcessorCount;
        info.devices[id].smpb       = prop.sharedMemPerBlock;
        info.devices[id].warp_size  = prop.warpSize;

#ifndef GGML_USE_MUSA
        int supports_coop_launch = 0;
        CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
        info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
#else
        info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
#if defined(GGML_USE_HIP)
        info.devices[id].smpbo = prop.sharedMemPerBlock;

        info.devices[id].cc = ggml_cuda_parse_id(prop.gcnArchName);
        if ((info.devices[id].cc & 0xff00) == 0x0) {
            GGML_LOG_WARN("invalid architecture ID received for device %d %s: %s  cc %d.%d\n",
                            id, prop.name, prop.gcnArchName, prop.major, prop.minor);

            // Fallback to prop.major and prop.minor
            if (prop.major > 0) {
                info.devices[id].cc = GGML_CUDA_CC_OFFSET_AMD + prop.major * 0x100;
                info.devices[id].cc += prop.minor * 0x10;
            }
        }
        GGML_LOG_INFO("  Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n",
                      id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
                      device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
        // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
        info.devices[id].warp_size = 32;
        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
        info.devices[id].cc += prop.minor * 0x10;
        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#else
        info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
        info.devices[id].cc = 100*prop.major + 10*prop.minor;
        GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                        id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
        std::string device_name(prop.name);
        if (device_name == "NVIDIA GeForce MX450") {
            turing_devices_without_mma.push_back({ id, device_name });
        } else if (device_name == "NVIDIA GeForce MX550") {
            turing_devices_without_mma.push_back({ id, device_name });
        } else if (device_name.substr(0, 21) == "NVIDIA GeForce GTX 16") {
            turing_devices_without_mma.push_back({ id, device_name });
        }

        // Temporary performance fix:
        // Setting device scheduling strategy for iGPUs with cc121 to "spinning" to avoid delays in cuda synchronize calls.
        // TODO: Check for future drivers the default scheduling strategy and
        // remove this call again when cudaDeviceScheduleSpin is default.
        if (prop.major == 12 && prop.minor == 1) {
            // cudaSetDeviceFlags applies to the *current* device, which this loop does not otherwise change.
            // Make device `id` current for the call so the flag reaches the iGPU itself, then restore the
            // previously current device.
            const int prev_device = ggml_cuda_get_device();
            ggml_cuda_set_device(id);
            CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin));
            ggml_cuda_set_device(prev_device);
        }

#endif  // defined(GGML_USE_HIP)
    }

    if (ggml_cuda_highest_compiled_arch(GGML_CUDA_CC_TURING) >= GGML_CUDA_CC_TURING && !turing_devices_without_mma.empty()) {
        GGML_LOG_INFO("The following devices will have suboptimal performance due to a lack of tensor cores:\n");
        for (const auto & dev : turing_devices_without_mma) {
            GGML_LOG_INFO("  Device %d: %s\n", dev.first, dev.second.c_str());
        }
        GGML_LOG_INFO(
            "Consider compiling with CMAKE_CUDA_ARCHITECTURES=61-virtual;80-virtual and -DGGML_CUDA_FORCE_MMQ to force the use of the Pascal code for Turing.\n");
    }

    // turn the cumulative VRAM offsets into fractions of the total
    for (int id = 0; id < info.device_count; ++id) {
        info.default_tensor_split[id] /= total_vram;
    }

    // configure logging to stdout
    // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));

    return info;
}
 314
 // Return the process-wide device table, built by ggml_cuda_init() on first use.
// A function-local static is initialized exactly once, even if several threads call this concurrently.
const ggml_cuda_device_info & ggml_cuda_info() {
    static ggml_cuda_device_info cached_info = ggml_cuda_init();
    return cached_info;
}
 319
 320// #define DEBUG_CUDA_MALLOC
 321
 // buffer pool for cuda (legacy)
// Buffers returned through free() are cached in a fixed table of MAX_BUFFERS slots and reused by alloc(),
// which picks the best fit among cached buffers that are large enough. New device allocations are over-sized
// by ~5% and rounded up to a multiple of 256 bytes to make later reuse more likely.
struct ggml_cuda_pool_leg : public ggml_cuda_pool {
    static const int MAX_BUFFERS = 256;

    int device;
    struct ggml_cuda_buffer {
        void * ptr = nullptr;
        size_t size = 0;
    };

    ggml_cuda_buffer buffer_pool[MAX_BUFFERS] = {};
    size_t pool_size = 0; // bytes currently allocated from the device by this pool (cached or handed out)

    explicit ggml_cuda_pool_leg(int device) :
        device(device) {
    }

    // Frees every cached buffer. All handed-out buffers must have been returned by then (asserted via pool_size).
    ~ggml_cuda_pool_leg() {
        ggml_cuda_set_device(device);
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer & b = buffer_pool[i];
            if (b.ptr != nullptr) {
                CUDA_CHECK(cudaFree(b.ptr));
                pool_size -= b.size;
            }
        }
        GGML_ASSERT(pool_size == 0);
    }

    // Return a buffer of at least `size` bytes. *actual_size receives the buffer's real size, which is the
    // size to pass back to free().
    void * alloc(size_t size, size_t * actual_size) override {
#ifdef DEBUG_CUDA_MALLOC
        int nnz = 0;
        size_t max_size = 0;
#endif
        // cached buffers that exceed the request by 64 GiB or more are never reused
        size_t best_diff = 1ull << 36;
        int ibest = -1;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer & b = buffer_pool[i];
            if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
                ++nnz;
                if (b.size > max_size) max_size = b.size;
#endif
                if (b.size >= size) {
                    size_t diff = b.size - size;
                    if (diff < best_diff) {
                        best_diff = diff;
                        ibest = i;
                        if (!best_diff) {
                            // exact fit: no better candidate can exist
                            return take_cached(ibest, actual_size);
                        }
                    }
                }
            }
        }
        if (ibest >= 0) {
            return take_cached(ibest, actual_size);
        }
        void * ptr;
        size_t look_ahead_size = (size_t) (1.05 * size);
        look_ahead_size = 256 * ((look_ahead_size + 255)/256);
        // ggml_cuda_device_malloc makes `device` current before allocating
        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
        *actual_size = look_ahead_size;
        pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
        GGML_LOG_INFO("%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, device, nnz,
                           (uint32_t)(max_size / 1024 / 1024), (uint32_t)(pool_size / 1024 / 1024), (uint32_t)(size / 1024 / 1024));
#endif
        return ptr;
    }

    // Cache `ptr` (of `size` bytes) in the first empty slot; if every slot is taken, free it on the device.
    void free(void * ptr, size_t size) override {
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cuda_buffer & b = buffer_pool[i];
            if (b.ptr == nullptr) {
                b.ptr = ptr;
                b.size = size;
                return;
            }
        }
        GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_BUFFERS\n");
        ggml_cuda_set_device(device);
        CUDA_CHECK(cudaFree(ptr));
        pool_size -= size;
    }

private:
    // Hand out cached buffer `i`: report its size, mark its slot empty and return its pointer.
    void * take_cached(int i, size_t * actual_size) {
        ggml_cuda_buffer & b = buffer_pool[i];
        void * ptr = b.ptr;
        *actual_size = b.size;
        b.ptr = nullptr;
        b.size = 0;
        return ptr;
    }
};
 418
 // pool with virtual memory
#if defined(GGML_USE_VMM)
// Stack-like pool backed by virtual memory management (VMM).
// - On first growth, a CUDA_POOL_VMM_MAX_SIZE virtual address range is reserved.
// - Physical memory is created and mapped at the end of the pool, in multiples of the device's allocation
//   granularity, whenever a request does not fit.
// - Allocations are bump-allocated from the start of the range, so free() must be called in reverse order
//   of alloc() (checked by an assertion).
// - Mapped memory is only unmapped when the pool is destroyed.
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
    static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB

    int device;
    CUdeviceptr pool_addr = 0; // start of the reserved virtual range (0 until the first growth)
    size_t pool_used = 0;      // bytes handed out, counted from pool_addr
    size_t pool_size = 0;      // bytes of physical memory mapped starting at pool_addr
    size_t granularity;        // allocation granularity reported for the device
#if defined(GGML_USE_HIP)
    std::vector<std::pair<CUdeviceptr, size_t>> mappings; // every individual mapping, unmapped one by one on HIP
#endif

    explicit ggml_cuda_pool_vmm(int device) :
        device(device),
        granularity(ggml_cuda_info().devices[device].vmm_granularity) {
    }

    ~ggml_cuda_pool_vmm() {
        if (pool_addr != 0) {
#if defined(GGML_USE_HIP)
            // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
            for (std::pair<CUdeviceptr, size_t> & mapping : mappings) {
                CU_CHECK(cuMemUnmap(mapping.first, mapping.second));
            }
#else
            CU_CHECK(cuMemUnmap(pool_addr, pool_size));
#endif
            CU_CHECK(cuMemAddressFree(pool_addr, CUDA_POOL_VMM_MAX_SIZE));
        }
    }

    void * alloc(size_t size, size_t * actual_size) override {
        // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types
        const size_t alignment = 128;
        size = alignment * ((size + alignment - 1) / alignment);

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = size - avail;
            reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);

            GGML_ASSERT(pool_size + reserve_size <= CUDA_POOL_VMM_MAX_SIZE);

            // allocate more physical memory
            CUmemAllocationProp prop = {};
            prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
            prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device;
            CUmemGenericAllocationHandle handle;
            CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                CU_CHECK(cuMemAddressReserve(&pool_addr, CUDA_POOL_VMM_MAX_SIZE, 0, 0, 0));
            }

            // map at the end of the pool
            CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
            CU_CHECK(cuMemMap(start_ptr, reserve_size, 0, handle, 0));
#if defined(GGML_USE_HIP)
            mappings.push_back({start_ptr, reserve_size});
#endif

            // the memory allocation handle is no longer needed after mapping
            CU_CHECK(cuMemRelease(handle));

            // set access on the newly mapped range
            CUmemAccessDesc access = {};
            access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
            access.location.id = device;
            access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
            CU_CHECK(cuMemSetAccess(start_ptr, reserve_size, &access, 1));

            // add to the pool
            pool_size += reserve_size;

            //printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
            //       device, (unsigned long long) (pool_size/1024/1024),
            //       (unsigned long long) (reserve_size/1024/1024));
        }

        GGML_ASSERT(pool_addr != 0);

        void * ptr = (void *) ((CUdeviceptr)((char *)(pool_addr) + pool_used));
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_CUDA_MALLOC
        // %p: passing a pointer to %llx is undefined behavior
        printf("cuda pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        return ptr;
    }

    void free(void * ptr, size_t size) override {
#ifdef DEBUG_CUDA_MALLOC
        printf("cuda pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == (void *) ((char *)(pool_addr) + pool_used));
    }
};
#endif // defined(GGML_USE_VMM)
 529
 // Create the memory pool for `device`. The VMM-backed pool is used when the build enables VMM and the device
// supports it; otherwise the legacy buffer-caching pool is used. `stream_no` is currently unused.
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int                  device,
                                                                               [[maybe_unused]] int stream_no) {
#if defined(GGML_USE_VMM)
    const bool device_has_vmm = ggml_cuda_info().devices[device].vmm;
    if (device_has_vmm) {
        return std::make_unique<ggml_cuda_pool_vmm>(device);
    }
#endif // defined(GGML_USE_VMM)
    return std::make_unique<ggml_cuda_pool_leg>(device);
}
 539
 // Destroying a cuBLAS handle while a graph is being captured in a different thread can result in a CUDA error.
// These globals guard against that: backend-context destruction waits on ggml_cuda_lock_cv (holding
// ggml_cuda_lock) until ggml_cuda_lock_counter is zero. The counter is presumably raised for the duration of a
// graph capture elsewhere in this file.
static std::mutex              ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int>        ggml_cuda_lock_counter{0};
 546
 // Release the context's GPU resources: the copy event, every created stream and every cuBLAS handle.
// First waits (under ggml_cuda_lock) for ggml_cuda_lock_counter to reach zero, so that no cuBLAS handle is
// destroyed while another thread is capturing a graph.
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
    std::unique_lock<std::mutex> guard(ggml_cuda_lock);
    const auto lock_counter_is_zero = [] {
        return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0;
    };
    ggml_cuda_lock_cv.wait(guard, lock_counter_is_zero);

    if (copy_event != nullptr) {
        CUDA_CHECK(cudaEventDestroy(copy_event));
    }

    for (int dev = 0; dev < GGML_CUDA_MAX_DEVICES; ++dev) {
        for (int s = 0; s < GGML_CUDA_MAX_STREAMS; ++s) {
            const auto stream = streams[dev][s];
            if (stream != nullptr) {
                CUDA_CHECK(cudaStreamDestroy(stream));
            }
        }

        const auto handle = cublas_handles[dev];
        if (handle != nullptr) {
            CUBLAS_CHECK(cublasDestroy(handle));
        }
    }
}
 565
 566
 567// cuda buffer
 568
 // State of one (non-split) CUDA device buffer. It owns the device allocation `dev_ptr`, which is released
// with cudaFree when the context is destroyed; `name` is GGML_CUDA_NAME followed by the device index.
struct ggml_backend_cuda_buffer_context {
    int device;
    void * dev_ptr = nullptr;
    std::string name;

    ggml_backend_cuda_buffer_context(int device, void * dev_ptr) :
        device(device), dev_ptr(dev_ptr),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    // The context uniquely owns dev_ptr: a copy would free the same device allocation twice.
    ggml_backend_cuda_buffer_context(const ggml_backend_cuda_buffer_context &) = delete;
    ggml_backend_cuda_buffer_context & operator=(const ggml_backend_cuda_buffer_context &) = delete;

    ~ggml_backend_cuda_buffer_context() {
        CUDA_CHECK(cudaFree(dev_ptr));
    }
};
 583
 // Buffer destructor callback: deletes the buffer's context, whose destructor frees the device allocation.
static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    delete static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
}
 588
 // A buffer is identified as a CUDA device buffer by its free_buffer callback being this file's implementation.
static bool ggml_backend_buffer_is_cuda(ggml_backend_buffer_t buffer) {
    const auto free_fn = buffer->iface.free_buffer;
    return free_fn == ggml_backend_cuda_buffer_free_buffer;
}
 592
 // Base address of the buffer: the start of the device allocation held by its context.
static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
    return static_cast<ggml_backend_cuda_buffer_context *>(buffer->context)->dev_ptr;
}
 597
 // Per-tensor initialization. Views need no work. For quantized tensors in buffers not used for compute, the
// bytes between ggml_nbytes() and the buffer type's padded allocation size are zeroed to avoid possible NaNs.
static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    if (tensor->view_src != NULL) {
        assert(tensor->view_src->buffer->buft == buffer->buft);
        return GGML_STATUS_SUCCESS;
    }

    const bool zero_padding = ggml_is_quantized(tensor->type) &&
        ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE;
    if (!zero_padding) {
        return GGML_STATUS_SUCCESS;
    }

    const size_t data_bytes  = ggml_nbytes(tensor);
    const size_t alloc_bytes = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
    if (alloc_bytes <= data_bytes) {
        return GGML_STATUS_SUCCESS;
    }

    const auto * ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemset(static_cast<char *>(tensor->data) + data_bytes, 0, alloc_bytes - data_bytes));
    return GGML_STATUS_SUCCESS;
}
 618
 // Fill `size` bytes of `tensor`, starting `offset` bytes into its data, with `value`; waits for completion.
static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    const auto * ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    char * dst = static_cast<char *>(tensor->data) + offset;

    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemsetAsync(dst, value, size, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
 626
 // Copy `size` bytes from host memory `data` into `tensor`, starting `offset` bytes into its data;
// waits for the copy to complete.
static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    const auto * ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    char * dst = static_cast<char *>(tensor->data) + offset;

    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemcpyAsync(dst, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
 634
 // Copy `size` bytes of `tensor`, starting `offset` bytes into its data, into host memory `data`;
// waits for the copy to complete.
static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    const auto * ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    const char * src = static_cast<const char *>(tensor->data) + offset;

    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemcpyAsync(data, src, size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
 642
 // Copy `src` into `dst` (a tensor of this CUDA buffer) when `src` also lives in a CUDA device buffer.
// - Same-device copies are device-to-device memcpys.
// - Cross-device copies are peer copies, unless GGML_CUDA_NO_PEER_COPY is defined, in which case false is
//   returned (presumably so the caller can fall back to a copy through host memory).
// Returns false if `src` is not in a CUDA buffer, and true once the copy has completed.
static bool ggml_backend_cuda_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    if (ggml_backend_buffer_is_cuda(src->buffer)) {
        ggml_backend_cuda_buffer_context * src_ctx = (ggml_backend_cuda_buffer_context *)src->buffer->context;
        ggml_backend_cuda_buffer_context * dst_ctx = (ggml_backend_cuda_buffer_context *)dst->buffer->context;
        if (src_ctx->device == dst_ctx->device) {
            // Like the other buffer operations, make the buffer's device current before using
            // cudaStreamPerThread, so the copy and the synchronization below use that device's per-thread stream.
            ggml_cuda_set_device(src_ctx->device);
            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(src), cudaMemcpyDeviceToDevice, cudaStreamPerThread));
        } else {
#ifdef GGML_CUDA_NO_PEER_COPY
            return false;
#else
            ggml_cuda_set_device(src_ctx->device);
            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, dst_ctx->device, src->data, src_ctx->device, ggml_nbytes(src), cudaStreamPerThread));
#endif
        }
        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
        return true;
    }
    return false;

    GGML_UNUSED(buffer);
}
 663
 // Set every byte of the whole buffer to `value`; waits for completion.
static void ggml_backend_cuda_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    const auto * ctx = static_cast<ggml_backend_cuda_buffer_context *>(buffer->context);
    const size_t n_bytes = buffer->size;

    ggml_cuda_set_device(ctx->device);
    CUDA_CHECK(cudaMemsetAsync(ctx->dev_ptr, value, n_bytes, cudaStreamPerThread));
    CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
}
 671
 // Callback table shared by all (non-split) CUDA device buffers; the `reset` callback is left unset.
static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = {
    /* .free_buffer   = */ ggml_backend_cuda_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_cuda_buffer_get_base,
    /* .init_tensor   = */ ggml_backend_cuda_buffer_init_tensor,
    /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor,
    /* .set_tensor    = */ ggml_backend_cuda_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_cuda_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_cuda_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_cuda_buffer_clear,
    /* .reset         = */ nullptr,
};
 683
 684// cuda buffer type
 // Context of a per-device CUDA buffer type.
struct ggml_backend_cuda_buffer_type_context {
    int         device; // index of the device the buffer type allocates on
    std::string name;   // name returned by the buffer type's get_name callback
};
 689
 // Name of the buffer type, as stored in its context.
static const char * ggml_backend_cuda_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    const auto * type_ctx = static_cast<const ggml_backend_cuda_buffer_type_context *>(buft->context);
    return type_ctx->name.c_str();
}
 695
 // A buffer type is identified as a CUDA device buffer type by its get_name callback being this file's implementation.
static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
    const auto get_name_fn = buft->iface.get_name;
    return get_name_fn == ggml_backend_cuda_buffer_type_get_name;
}
 699
 // Allocate a device buffer of `size` bytes on the buffer type's device.
// On allocation failure, the error is logged and nullptr is returned.
static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    const auto * buft_ctx = static_cast<ggml_backend_cuda_buffer_type_context *>(buft->context);
    const int device = buft_ctx->device;

    ggml_cuda_set_device(device);

    void * dev_ptr;
    const cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, device);
    if (err == cudaSuccess) {
        auto * ctx = new ggml_backend_cuda_buffer_context(device, dev_ptr);
        return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size);
    }

    // the failed allocation is left as the last runtime error; clear it so later error checks are unaffected
    (void)cudaGetLastError();
    GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n", __func__, size / 1024.0 / 1024.0, device, cudaGetErrorString(err));
    return nullptr;
}
 718
 // Every CUDA buffer type uses a fixed 128-byte alignment.
static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);
    constexpr size_t alignment = 128;
    return alignment;
}
 724
 // Bytes to allocate for `tensor`: ggml_nbytes(), plus, for quantized tensors whose ne[0] is not a multiple of
// MATRIX_ROW_PADDING, the size of (MATRIX_ROW_PADDING - ne[0] % MATRIX_ROW_PADDING) extra elements, presumably so
// that kernels can read a row padded up to MATRIX_ROW_PADDING elements. Requires a contiguous first dimension.
static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    GGML_UNUSED(buft);

    size_t alloc_size = ggml_nbytes(tensor);
    if (!ggml_is_quantized(tensor->type)) {
        return alloc_size;
    }

    const int64_t remainder = tensor->ne[0] % MATRIX_ROW_PADDING;
    if (remainder != 0) {
        GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor));
        alloc_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - remainder);
    }
    return alloc_size;
}
 740
 // Callback table of the per-device CUDA buffer types; get_max_size and is_host are left unset.
static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_cuda_buffer_type_get_name,
    /* .alloc_buffer   = */ ggml_backend_cuda_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_cuda_buffer_type_get_alignment,
    /* .get_max_size   = */ nullptr, // defaults to SIZE_MAX
    /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size,
    /* .is_host        = */ nullptr,
};
 749
 // Return the buffer type for device `device`, or nullptr if `device` is not a valid device index.
// The buffer types for all devices are created together on the first call (guarded by a mutex). They live in
// a static array and are never freed.
ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    // reject negative indices too: they would otherwise index before the start of the static array below
    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
        return nullptr;
    }

    static ggml_backend_buffer_type ggml_backend_cuda_buffer_types[GGML_CUDA_MAX_DEVICES];

    static bool ggml_backend_cuda_buffer_type_initialized = false;

    if (!ggml_backend_cuda_buffer_type_initialized) {
        for (int i = 0; i < ggml_backend_cuda_get_device_count(); i++) {
            ggml_backend_cuda_buffer_types[i] = {
                /* .iface    = */ ggml_backend_cuda_buffer_type_interface,
                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), i),
                /* .context  = */ new ggml_backend_cuda_buffer_type_context{i, GGML_CUDA_NAME + std::to_string(i)},
            };
        }
        ggml_backend_cuda_buffer_type_initialized = true;
    }

    return &ggml_backend_cuda_buffer_types[device];
}
 775
 776// cuda split buffer
 777
 778static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
 779    int64_t row_rounding = 0;
 780    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 781        if (tensor_split[id] >= (id + 1 < ggml_backend_cuda_get_device_count() ? tensor_split[id + 1] : 1.0f)) {
 782            continue;
 783        }
 784
 785        const int cc = ggml_cuda_info().devices[id].cc;
 786        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
 787    }
 788    return row_rounding;
 789}
 790
 791static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split, int id) {
 792    const int64_t nrows = ggml_nrows(tensor);
 793    const int64_t rounding = get_row_rounding(tensor_split);
 794
 795    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
 796    *row_low -= *row_low % rounding;
 797
 798    if (id == ggml_backend_cuda_get_device_count() - 1) {
 799        *row_high = nrows;
 800    } else {
 801        *row_high = nrows*tensor_split[id + 1];
 802        *row_high -= *row_high % rounding;
 803    }
 804}
 805
 806static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
 807    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 808
 809    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
 810}
 811
// Per-buffer-type state of a row-split CUDA buffer type.
// Instances are created by ggml_backend_cuda_split_buffer_type.
struct ggml_backend_cuda_split_buffer_type_context {
    int main_device;                                     // device this buffer type is registered with
    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split; // cumulative start fraction of each device's rows
    std::string name;                                    // e.g. GGML_CUDA_NAME + "<main_device>_Split"
};
 817
 818struct ggml_backend_cuda_split_buffer_context {
 819    ~ggml_backend_cuda_split_buffer_context() {
 820        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
 821            for (int id = 0; id < GGML_CUDA_MAX_DEVICES; ++id) {
 822                for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
 823                    if (extra->events[id][is] != nullptr) {
 824                        CUDA_CHECK(cudaEventDestroy(extra->events[id][is]));
 825                    }
 826                }
 827                if (extra->data_device[id] != nullptr) {
 828                    CUDA_CHECK(cudaFree(extra->data_device[id]));
 829                }
 830            }
 831            delete extra;
 832        }
 833    }
 834
 835    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
 836};
 837
 838
 839static void ggml_backend_cuda_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 840    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
 841    delete ctx;
 842}
 843
 844static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buffer) {
 845    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
 846    return (void *)0x1000;
 847
 848    GGML_UNUSED(buffer);
 849}
 850
 851static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
 852    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 853    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 854
 855    ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
 856    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 857
 858    const int64_t ne0 = tensor->ne[0];
 859
 860    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 861    ctx->tensor_extras.push_back(extra);
 862
 863    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 864        int64_t row_low, row_high;
 865        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
 866
 867        int64_t nrows_split = row_high - row_low;
 868        if (nrows_split == 0) {
 869            continue;
 870        }
 871
 872        size_t size = ggml_nbytes_split(tensor, nrows_split);
 873        const size_t original_size = size;
 874
 875        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
 876        if (ne0 % MATRIX_ROW_PADDING != 0) {
 877            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
 878        }
 879
 880        // FIXME: do not crash if cudaMalloc fails
 881        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
 882        ggml_cuda_set_device(id);
 883        char * buf;
 884        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
 885
 886        // set padding to 0 to avoid possible NaN values
 887        if (size > original_size) {
 888            CUDA_CHECK(cudaMemset(buf + original_size, 0, size - original_size));
 889        }
 890
 891        extra->data_device[id] = buf;
 892
 893        for (int64_t is = 0; is < GGML_CUDA_MAX_STREAMS; ++is) {
 894            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id][is], cudaEventDisableTiming));
 895        }
 896    }
 897    tensor->extra = extra;
 898    return GGML_STATUS_SUCCESS;
 899}
 900
 901static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
 902    // split tensors must always be set in their entirety at once
 903    GGML_ASSERT(offset == 0);
 904    GGML_ASSERT(size == ggml_nbytes(tensor));
 905    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 906
 907    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 908
 909    const int64_t ne0 = tensor->ne[0];
 910    const size_t nb1 = tensor->nb[1];
 911    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
 912
 913    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 914        int64_t row_low, row_high;
 915        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
 916
 917        int64_t nrows_split = row_high - row_low;
 918        if (nrows_split == 0) {
 919            continue;
 920        }
 921
 922        const size_t offset_split = row_low*nb1;
 923        size_t size = ggml_nbytes_split(tensor, nrows_split);
 924        const size_t original_size = size;
 925
 926        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
 927        if (ne0 % MATRIX_ROW_PADDING != 0) {
 928            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
 929        }
 930
 931        const char * buf_host = (const char *)data + offset_split;
 932        CUDA_CHECK(cudaMemcpyAsync(extra->data_device[id], buf_host, original_size, cudaMemcpyHostToDevice, cudaStreamPerThread));
 933    }
 934
 935    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 936        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 937    }
 938}
 939
 940static void ggml_backend_cuda_split_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
 941    // split tensors must always be set in their entirety at once
 942    GGML_ASSERT(offset == 0);
 943    GGML_ASSERT(size == ggml_nbytes(tensor));
 944    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");
 945
 946    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *)buffer->buft->context;
 947
 948    const int64_t ne0 = tensor->ne[0];
 949    const size_t nb1 = tensor->nb[1];
 950    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
 951
 952    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 953        int64_t row_low, row_high;
 954        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, id);
 955
 956        int64_t nrows_split = row_high - row_low;
 957        if (nrows_split == 0) {
 958            continue;
 959        }
 960
 961        const size_t offset_split = row_low*nb1;
 962        size_t size = ggml_nbytes_split(tensor, nrows_split);
 963        const size_t original_size = size;
 964
 965        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
 966        if (ne0 % MATRIX_ROW_PADDING != 0) {
 967            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
 968        }
 969
 970        char * buf_host = (char *)data + offset_split;
 971        CUDA_CHECK(cudaMemcpyAsync(buf_host, extra->data_device[id], original_size, cudaMemcpyDeviceToHost, cudaStreamPerThread));
 972    }
 973
 974    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
 975        CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread));
 976    }
 977}
 978
// Clear callback for split buffers. It is a no-op: neither `buffer` nor `value` is used,
// and no device memory is touched.
// NOTE(review): callers expecting clear() to fill a split buffer with `value` will not get that here; confirm intended.
static void ggml_backend_cuda_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    GGML_UNUSED(buffer);
    GGML_UNUSED(value);
}
 983
// Buffer vtable for row-split CUDA buffers. Tensor data lives in per-device allocations referenced from
// tensor->extra. memset_tensor, cpy_tensor and reset are NULL (not provided).
static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_cuda_split_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_cuda_split_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_cuda_split_buffer_init_tensor,
    /* .memset_tensor   = */ NULL,
    /* .set_tensor      = */ ggml_backend_cuda_split_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_cuda_split_buffer_get_tensor,
    /* .cpy_tensor      = */ NULL,
    /* .clear           = */ ggml_backend_cuda_split_buffer_clear,
    /* .reset           = */ NULL,
};
 995
 996// cuda split buffer type
 997
 998static const char * ggml_backend_cuda_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
 999    ggml_backend_cuda_split_buffer_type_context * ctx = (ggml_backend_cuda_split_buffer_type_context *)buft->context;
1000
1001    return ctx->name.c_str();
1002}
1003
// True if `buft` is a CUDA split buffer type.
// Split buffer types are identified by their get_name callback pointer.
static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_get_name;
}
1007
// Creates an (initially empty) split buffer.
// No device memory is allocated here, because the exact per-device split is only known after rounding,
// for each tensor. ggml_backend_cuda_split_buffer_init_tensor allocates the device slices per tensor.
// `size` is still meaningful: it is the maximum cumulative size of those per-device allocations, as reported by
// get_alloc_size. ggml-alloc enforces that limit while placing tensors, so it must be correct.
static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    auto * buf_ctx = new ggml_backend_cuda_split_buffer_context();
    return ggml_backend_buffer_init(buft, ggml_backend_cuda_split_buffer_interface, buf_ctx, size);
}
1017
// Alignment, in bytes, reported for split buffer types.
// The value does not depend on the buffer type instance.
static size_t ggml_backend_cuda_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    const size_t alignment_bytes = 128;
    return alignment_bytes;
}
1023
// Total bytes needed across all devices for a row-split `tensor`.
// This sums, for every device with a non-empty row range, the size of its rows. If ne[0] is not a multiple of
// MATRIX_ROW_PADDING, it also adds the padding of that device's last row. This matches the allocation sizes
// used in ggml_backend_cuda_split_buffer_init_tensor.
static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    auto * type_ctx = (ggml_backend_cuda_split_buffer_type_context *) buft->context;
    GGML_ASSERT(ggml_is_contiguous(tensor) && "split buffers only supported for contiguous tensors");

    const int64_t tail  = tensor->ne[0] % MATRIX_ROW_PADDING;
    const int n_devices = ggml_backend_cuda_get_device_count();

    size_t total = 0;
    for (int id = 0; id < n_devices; ++id) {
        int64_t first_row;
        int64_t end_row;
        get_row_split(&first_row, &end_row, tensor, type_ctx->tensor_split, id);

        const int64_t n_rows = end_row - first_row;
        if (n_rows != 0) {
            total += ggml_nbytes_split(tensor, n_rows);

            // pad last row to a multiple of MATRIX_ROW_PADDING elements to avoid out-of-bounds memory accesses
            if (tail != 0) {
                total += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - tail);
            }
        }
    }
    return total;
}
1051
// Split buffers are never reported as host buffers.
// Their data lives in per-device allocations made in init_tensor.
static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    const bool is_host = false;
    return is_host;
}
1057
// Buffer-type vtable for row-split CUDA buffer types.
// get_max_size is NULL, which per the inline note means SIZE_MAX.
static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_cuda_split_buffer_type_get_name,
    /* .alloc_buffer     = */ ggml_backend_cuda_split_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
    /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
    /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
};
1066
// Returns the row-split buffer type for (main_device, tensor_split). The result is cached in a
// function-local map, so repeated calls with the same arguments return the same pointer.
//
// tensor_split gives the relative share of rows for each device. It is converted to cumulative start
// fractions normalized to [0, 1). The default split (ggml_cuda_info().default_tensor_split) is used when:
//  - tensor_split is nullptr;
//  - all its GGML_CUDA_MAX_DEVICES entries are zero;
//  - the shares of the existing devices do not sum to a positive value. Normalizing would divide by zero
//    (or by a non-positive / NaN sum) and produce invalid split points.
// NOTE(review): when non-null, tensor_split is read for GGML_CUDA_MAX_DEVICES entries.
// The whole function runs under a function-local mutex.
ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    static std::map<std::pair<int, std::array<float, GGML_CUDA_MAX_DEVICES>>, struct ggml_backend_buffer_type> buft_map;

    const int device_count = ggml_backend_cuda_get_device_count();

    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split_arr = {};

    bool use_default = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_CUDA_MAX_DEVICES, [](float x) { return x == 0.0f; });
    if (!use_default) {
        // convert per-device shares into cumulative start points
        float split_sum = 0.0f;
        for (int i = 0; i < device_count; ++i) {
            tensor_split_arr[i] = split_sum;
            split_sum += tensor_split[i];
        }

        if (split_sum > 0.0f) {
            for (int i = 0; i < device_count; ++i) {
                tensor_split_arr[i] /= split_sum;
            }
        } else {
            // only entries beyond the device count are non-zero (or the sum is not positive):
            // normalizing would yield NaN/invalid split points, so use the default split instead
            use_default = true;
        }
    }
    if (use_default) {
        tensor_split_arr = ggml_cuda_info().default_tensor_split;
    }

    auto it = buft_map.find({main_device, tensor_split_arr});
    if (it != buft_map.end()) {
        return &it->second;
    }
    auto * ctx = new ggml_backend_cuda_split_buffer_type_context{
        main_device,
        tensor_split_arr,
        GGML_CUDA_NAME + std::to_string(main_device) + "_Split",
    };

    struct ggml_backend_buffer_type buft {
        /* .iface   = */ ggml_backend_cuda_split_buffer_type_interface,
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), main_device),
        /* .context = */ ctx,
    };

    auto result = buft_map.emplace(std::make_pair(main_device, tensor_split_arr), buft);
    return &result.first->second;
}
1108
1109// host buffer type
1110
// Name of the pinned host buffer type.
// This function's address is also used by ggml_backend_buft_is_cuda_host to recognize the type.
static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(buft);

    static const char * const host_name = GGML_CUDA_NAME "_Host";
    return host_name;
}
1116
// True if `buft` is the CUDA pinned host buffer type.
// It is identified by its get_name callback pointer.
static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
    return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
}
1120
// free_buffer override for pinned host buffers: releases buffer->context with cudaFreeHost.
// NOTE(review): assumes the CPU buffer created by ggml_backend_cpu_buffer_from_ptr stores the pinned pointer in context; confirm.
static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    CUDA_CHECK(cudaFreeHost(buffer->context));
}
1124
// Allocates `size` bytes of pinned host memory with cudaMallocHost.
// Returns nullptr in two cases:
//  - the GGML_CUDA_NO_PINNED environment variable is set;
//  - the allocation fails. In that case the CUDA error state is cleared and a debug message is logged.
static void * ggml_cuda_host_malloc(size_t size) {
    const bool pinned_disabled = getenv("GGML_CUDA_NO_PINNED") != nullptr;
    if (pinned_disabled) {
        return nullptr;
    }

    void * host_ptr = nullptr;
    const cudaError_t status = cudaMallocHost((void **) &host_ptr, size);
    if (status == cudaSuccess) {
        return host_ptr;
    }

    // clear the error
    (void) cudaGetLastError();
    GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                   size / 1024.0 / 1024.0, cudaGetErrorString(status));
    return nullptr;
}
1142
// Allocates a host buffer, backed by pinned memory when possible.
// If pinned allocation fails or is disabled, a plain CPU buffer is returned instead. Such a buffer carries the
// CPU buffer type, not `buft`.
static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * pinned = ggml_cuda_host_malloc(size);
    if (pinned == nullptr) {
        // fallback to cpu buffer
        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    }

    // wrap the pinned memory in a CPU buffer, retag it with this buffer type and release it with cudaFreeHost
    ggml_backend_buffer_t wrapped = ggml_backend_cpu_buffer_from_ptr(pinned, size);
    wrapped->buft              = buft;
    wrapped->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer;
    return wrapped;
}
1157
// Returns the process-wide pinned host buffer type (a function-local static).
// Buffers are allocated by ggml_backend_cuda_host_buffer_type_alloc_buffer.
// Alignment, alloc-size and is_host are taken from the CPU buffer type.
// The type is associated with device 0 of the CUDA backend registry.
ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_cuda_buffer_type_host = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_cuda_host_buffer_type_name,
            /* .alloc_buffer     = */ ggml_backend_cuda_host_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), 0),
        /* .context  = */ nullptr,
    };

    return &ggml_backend_cuda_buffer_type_host;
}
1174
1175//static bool ggml_backend_buffer_is_cuda_host(ggml_backend_buffer_t buffer) {
1176//    return buffer->buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
1177//}
1178
1179/// kernels
1180
// Signature of a per-device matrix-multiplication implementation passed to ggml_cuda_op_mul_mat as `op`.
// An implementation multiplies rows [row_low, row_high) of src0 (data at src0_dd_i) by src1_ncols columns of src1
// on `stream`, writing float results to dst_dd_i.
// src1 is provided as floats (src1_ddf_i) and, NOTE(review): presumably, as quantized data (src1_ddq_i) with
// padded row size src1_padded_row_size when a quantize step is used — confirm in ggml_cuda_op_mul_mat.
typedef void (*ggml_cuda_op_mul_mat_t)(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);
1186
// Peer access between devices is enabled by ggml_cuda_set_peer_access only while n_tokens is at most this value.
// It can be overridden at compile time.
#ifndef GGML_CUDA_PEER_MAX_BATCH_SIZE
#define GGML_CUDA_PEER_MAX_BATCH_SIZE 128
#endif // GGML_CUDA_PEER_MAX_BATCH_SIZE

// NOTE(review): presumably the number of src1 columns processed per chunk in the multi-device mul_mat path;
// its use is outside this view — confirm.
#define MUL_MAT_SRC1_COL_STRIDE 128
1192
// Asynchronously copies rows [i1_low, i1_high) of the 2D slice (i2, i3) of `src` into the packed buffer `dst`.
// In `dst`, rows are stored back to back with ts*ne0/bs bytes each. Three cases:
//  - fully packed source rows: a single linear copy;
//  - packed elements but strided rows: one 2D copy;
//  - strided elements: one 2D copy per row, one byte-run per element.
// Returns the first CUDA error encountered, or cudaSuccess.
static cudaError_t ggml_cuda_cpy_tensor_2d(
    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {

    const enum ggml_type type = src->type;
    const int64_t type_size  = ggml_type_size(type);
    const int64_t block_size = ggml_blck_size(type);

    const int64_t ne0 = src->ne[0];
    const int64_t nb0 = src->nb[0];
    const int64_t nb1 = src->nb[1];

    const int64_t nrows     = i1_high - i1_low;
    const int64_t row_bytes = type_size*ne0/block_size;

    const char * src_rows = (const char *) src->data + i1_low*nb1 + i2*src->nb[2] + i3*src->nb[3];
    char       * dst_rows = (char *) dst;

    if (nb0 != type_size) {
        // elements are strided: copy each row as a ne0 x 1 "matrix" with source pitch nb0
        for (int64_t r = 0; r < nrows; r++) {
            const void * row_src = (const void *) (src_rows + r*nb1);
            void       * row_dst = (void *) (dst_rows + r*type_size*ne0/block_size);
            const cudaError_t err = cudaMemcpy2DAsync(row_dst, type_size/block_size, row_src, nb0, type_size/block_size, ne0, cudaMemcpyDeviceToDevice, stream);
            if (err != cudaSuccess) {
                return err;
            }
        }
        return cudaSuccess;
    }

    if (nb1 == row_bytes) {
        // rows are contiguous with each other: one linear copy
        return cudaMemcpyAsync(dst_rows, src_rows, nrows*nb1, cudaMemcpyDeviceToDevice, stream);
    }

    // rows are packed internally but separated by pitch nb1
    return cudaMemcpy2DAsync(dst_rows, row_bytes, src_rows, nb1, row_bytes, nrows, cudaMemcpyDeviceToDevice, stream);
}
1227
// cuBLAS implementation of ggml_cuda_op_mul_mat_t for the current device.
// Computes dst rows [row_low, row_high) for src1_ncols columns of src1 and writes float results to dst_dd_i
// on `stream`. Each GEMM is issued as (OP_T, OP_N) with m = row_diff, n = src1_ncols, k = ne10.
// One of three paths is chosen:
//  1. BF16 GEMM (FP32 compute): the device supports bf16 and src0 is contiguous BF16 covering all of its rows.
//     src1 is converted to bf16 when needed; the bf16 result is converted to fp32.
//  2. FP16 GEMM: fast fp16 hardware is available, and src0 is F16 or quantized, contiguous, covers all rows,
//     with default precision. src0/src1 are converted to fp16 as needed.
//  3. FP32 SGEMM otherwise. src0/src1 are converted to fp32 as needed.
// src1_ddq_i and src1_padded_row_size are not used by this implementation.
static void ggml_cuda_op_mul_mat_cublas(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    GGML_ASSERT(src0_dd_i  != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_dd_i   != nullptr);

    const int64_t ne00 = src0->ne[0];
    const int64_t ne10 = src1->ne[0];

    const int64_t ne0 = dst->ne[0];

    const int64_t row_diff = row_high - row_low;

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // ldc == nrows of the matrix that cuBLAS writes into
    int64_t ldc = id == ctx.device ? ne0 : row_diff;

    const int cc = ggml_cuda_info().devices[id].cc;

    // bf16 GEMM is used on NVIDIA and AMD devices, and on Moore Threads devices from QY2 onward
    const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);

    // the fp16 path is only taken for whole (unsplit) F16/quantized src0 at default precision
    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;

    if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
        // path 1: BF16 inputs, FP32 compute, BF16 output converted to FP32
        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
        if (src1->type != GGML_TYPE_BF16) {
            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
            GGML_ASSERT(to_bf16_cuda != nullptr);
            size_t ne = src1_ncols*ne10;
            src1_as_bf16.alloc(ne);
            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
        }
        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
        // NOTE(review): sized row_diff*src1_ncols but written with leading dimension ldc; this relies on
        // ldc == row_diff here (row_diff == src0->ne[1], presumably equal to ne0) — confirm.
        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);

        const float alpha_f32 = 1.0f;
        const float beta_f32  = 0.0f;

        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
        CUBLAS_CHECK(
            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                    row_diff, src1_ncols, ne10,
                    &alpha_f32,  src0_ptr,       CUDA_R_16BF, ne00,
                                 src1_ptr,       CUDA_R_16BF, ne10,
                    &beta_f32,   dst_bf16.get(), CUDA_R_16BF, ldc,
                    CUBLAS_COMPUTE_32F,
                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));

        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
    } else if (fast_fp16_hardware_available(cc) && use_fp16) {
        // path 2: convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
        if (src0->type != GGML_TYPE_F16) {
            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
            GGML_ASSERT(to_fp16_cuda != nullptr);
            size_t ne = row_diff*ne00;
            src0_as_f16.alloc(ne);
            to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
        }
        const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
        if (src1->type != GGML_TYPE_F16) {
            const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
            GGML_ASSERT(to_fp16_cuda != nullptr);
            size_t ne = src1_ncols*ne10;
            src1_as_f16.alloc(ne);
            to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
        }
        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();

        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));

        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
            // CDNA / RDNA4: fp16 inputs with FP32 compute, written directly as fp32 into dst
            const float alpha = 1.0f;
            const float beta = 0.0f;
            CUBLAS_CHECK(
                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                        row_diff, src1_ncols, ne10,
                        &alpha, src0_ptr,  CUDA_R_16F, ne00,
                                src1_ptr,  CUDA_R_16F, ne10,
                        &beta,   dst_dd_i, CUDA_R_32F, ldc,
                        CUBLAS_COMPUTE_32F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
        } else {
            // other devices: FP16 compute into an fp16 temporary, then convert to fp32
            ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

            const half alpha_f16 = 1.0f;
            const half beta_f16 = 0.0f;

            CUBLAS_CHECK(
                cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                        row_diff, src1_ncols, ne10,
                        &alpha_f16, src0_ptr,      CUDA_R_16F, ne00,
                                    src1_ptr,      CUDA_R_16F, ne10,
                        &beta_f16,  dst_f16.get(), CUDA_R_16F, ldc,
                        CUBLAS_COMPUTE_16F,
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP));

            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
            to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
        }
    } else {
        // path 3: convert non-F32 inputs to fp32 and use cublasSgemm, writing directly into dst
        ggml_cuda_pool_alloc<float> src0_ddq_as_f32(ctx.pool(id));
        ggml_cuda_pool_alloc<float> src1_ddq_as_f32(ctx.pool(id));

        if (src0->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
            src0_ddq_as_f32.alloc(row_diff*ne00);
            to_fp32_cuda(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
        }
        if (src1->type != GGML_TYPE_F32) {
            const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src1->type);
            GGML_ASSERT(to_fp32_cuda != nullptr);
            src1_ddq_as_f32.alloc(src1_ncols*ne10);
            to_fp32_cuda(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
        }

        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();

        const float alpha = 1.0f;
        const float beta = 0.0f;

        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
        CUBLAS_CHECK(
            cublasSgemm(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
                    row_diff, src1_ncols, ne10,
                    &alpha, src0_ddf_i,  ne00,
                            src1_ddf1_i, ne10,
                    &beta,  dst_dd_i,    ldc));
    }

    GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size);
}
1373
// Enables or disables peer access between main_device and every other device. Access is enabled when
// n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE and disabled otherwise.
// The last requested state is cached in a function-local static, and nothing happens if it is unchanged.
// The device calls are only compiled in NDEBUG builds; in other builds only the cached flag is updated.
// When they run:
//  1. every device is synchronized first;
//  2. the "already enabled" / "not enabled" errors are tolerated and cleared;
//  3. main_device is left as the current device.
// NOTE(review): peer_access_enabled is a plain static with no locking.
static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
    static bool peer_access_enabled = false;

    const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE;

    if (peer_access_enabled == enable_peer_access) {
        return;
    }

#ifdef NDEBUG
    // wait for all outstanding work on every device before changing peer access
    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
        ggml_cuda_set_device(id);
        CUDA_CHECK(cudaDeviceSynchronize());
    }

    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
        ggml_cuda_set_device(id);

        for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) {
            if (id == id_other) {
                continue;
            }
            // only pairs that involve the main device are touched
            if (id != main_device && id_other != main_device) {
                continue;
            }

            int can_access_peer;
            CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other));
            if (can_access_peer) {
                if (enable_peer_access) {
                    cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0);
                    if (err != cudaErrorPeerAccessAlreadyEnabled) {
                        CUDA_CHECK(err);
                    } else {
                        // reset the error
                        (void)cudaGetLastError();
                    }
                } else {
                    cudaError_t err = cudaDeviceDisablePeerAccess(id_other);
                    if (err != cudaErrorPeerAccessNotEnabled) {
                        CUDA_CHECK(err);
                    } else {
                        // reset the error
                        (void)cudaGetLastError();
                    }
                }
            }
        }
    }

    ggml_cuda_set_device(main_device);
#endif // NDEBUG

    peer_access_enabled = enable_peer_access;

    GGML_UNUSED(main_device);
}
1431
// Asynchronous 2D device-to-device copy of `height` rows of `width` bytes from `src` (pitch spitch, on srcDevice)
// to `dst` (pitch dpitch, on dstDevice) on `stream`.
// On CUDA this is issued as a single-slice cudaMemcpy3DPeerAsync. On HIP/MUSA it is a plain
// cudaMemcpy2DAsync, and the device ids are ignored.
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {

#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
    GGML_UNUSED(dstDevice);
    GGML_UNUSED(srcDevice);
    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
#else
    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices,
    // so describe the 2D copy as a depth-1 3D peer copy instead
    cudaMemcpy3DPeerParms params = {};
    params.srcDevice = srcDevice;
    params.srcPtr    = make_cudaPitchedPtr(src, spitch, spitch, height);
    params.dstDevice = dstDevice;
    params.dstPtr    = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
    params.extent    = make_cudaExtent(width, height, 1);
    return cudaMemcpy3DPeerAsync(&params, stream);
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA)
}
1451
// Generic driver for mul_mat implementations expressed as a per-slice callback `op`.
//
// For every participating device it prepares:
//   - src0: the tensor data itself (or this device's part of a split tensor), or a zero-filled temporary
//     copy when src0 is not contiguous,
//   - src1 as float (the tensor itself or a temporary buffer), optionally quantized via `quantize_src1`,
//   - a dst buffer: dst->data when dst lives on that device, otherwise a temporary pool allocation.
// It then loops over src1 columns in chunks of src1_col_stride and over all ne12*ne13 src1 matrices,
// calls `op` for each slice and, when dst is not on that device, copies the partial result back into dst.
//
// When src0 lives in a CUDA split buffer its rows are distributed over the devices according to the
// buffer type's tensor_split (rounded with get_row_rounding). Events stored in src0_extra order the work
// between the main device (ctx.device) and the other devices / streams.
//
// ctx           - backend context providing per-device streams and memory pools.
// src0          - left operand; for split buffers only a single 2D matrix is supported (asserted below).
// src1          - right operand; must be F32 unless it is a single matrix (ne[2] == ne[3] == 1).
// dst           - result, handled as float.
// op            - callback that computes one slice on one device/stream.
// quantize_src1 - optional; when non-null src1 is converted to q8_1 data before `op` is called.
static void ggml_cuda_op_mul_mat(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
    quantize_cuda_t quantize_src1) {

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
    const int64_t ne12 = src1->ne[2];
    const int64_t ne13 = src1->ne[3];
    const int64_t nrows1 = ggml_nrows(src1);

    const int64_t ne0 = dst->ne[0];
    const int64_t ne1 = dst->ne[1];

    // const int64_t nb10 = src1->nb[0];
    const int64_t nb11 = src1->nb[1];
    const int64_t nb12 = src1->nb[2];
    const int64_t nb13 = src1->nb[3];

    const int64_t nb2 = dst->nb[2];
    const int64_t nb3 = dst->nb[3];

    // the buffer contexts tell on which device src1 and dst are allocated
    ggml_backend_cuda_buffer_context * src1_ctx = (ggml_backend_cuda_buffer_context *) src1->buffer->context;
    ggml_backend_cuda_buffer_context * dst_ctx  = (ggml_backend_cuda_buffer_context *) dst->buffer->context;

    // batched src1 (more than one matrix) is only supported as F32
    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));

    // src0 is broadcast over dims 2 and 3: i02_divisor consecutive src1 matrices along dim 2
    // (i03_divisor along dim 3) share the same src0 matrix
    GGML_ASSERT(ne12 % ne02 == 0);
    GGML_ASSERT(ne13 % ne03 == 0);

    const int64_t i02_divisor = ne12 / ne02;
    const int64_t i03_divisor = ne13 / ne03;

    // type size / block size pairs used to convert element counts into byte sizes
    const size_t src0_ts = ggml_type_size(src0->type);
    const size_t src0_bs = ggml_blck_size(src0->type);
    const size_t q8_1_ts = sizeof(block_q8_1);
    const size_t q8_1_bs = QK8_1;

    const bool src0_is_contiguous = ggml_is_contiguous(src0);
    const bool src1_is_contiguous = ggml_is_contiguous(src1);

    // src1 row length rounded up to a multiple of MATRIX_ROW_PADDING; passed to quantize_src1 and op
    const int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);

    // split src0: a single 2D matrix, and src1 may not have more matrices than src0 (no broadcasting)
    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);
    GGML_ASSERT(!(split && ne02 > 1));
    GGML_ASSERT(!(split && ne03 > 1));
    GGML_ASSERT(!(split && ne02 < ne12));
    GGML_ASSERT(!(split && ne03 < ne13));

    // per-device data pointers and synchronization events of a split src0
    ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;


    std::array<float, GGML_CUDA_MAX_DEVICES> tensor_split;
    if (split) {
        ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
        tensor_split = buft_ctx->tensor_split;
    }

    // Per-device working state. The *_alloc members hold temporary pool allocations (only used when needed);
    // the raw pointers point either into those allocations or directly at tensor data.
    struct dev_data {
        int cc;

        ggml_cuda_pool_alloc<char>   src0_dd_alloc;
        ggml_cuda_pool_alloc<float> src1_ddf_alloc;
        ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
        ggml_cuda_pool_alloc<float>   dst_dd_alloc;

        char  *  src0_dd = nullptr;
        float * src1_ddf = nullptr; // float
        char  * src1_ddq = nullptr; // q8_1
        float *   dst_dd = nullptr;

        // rows [row_low, row_high) of src0 handled by this device
        int64_t  row_low;
        int64_t row_high;
    };

    dev_data dev[GGML_CUDA_MAX_DEVICES];

    int used_devices = 0;

    // Pass 1: determine the src0 row range of every device.
    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
        dev[id].cc = ggml_cuda_info().devices[id].cc;

        // by default, use all rows
        dev[id].row_low  = 0;
        dev[id].row_high = ne01;

        // for multi GPU, get the row boundaries from tensor split
        // and round to mul_mat_q tile sizes
        // (tensor_split[id] is the start fraction of device id, tensor_split[id + 1] its end fraction)
        if (split) {
            const int64_t rounding = get_row_rounding(tensor_split);

            if (id != 0) {
                dev[id].row_low  = ne01*tensor_split[id];
                if (dev[id].row_low < ne01) {
                    dev[id].row_low -= dev[id].row_low % rounding;
                }
            }

            if (id != ggml_backend_cuda_get_device_count() - 1) {
                dev[id].row_high  = ne01*tensor_split[id + 1];
                if (dev[id].row_high < ne01) {
                    dev[id].row_high -= dev[id].row_high % rounding;
                }
            }
        }
    }

    // Pass 2: set up src0/src1/dst pointers on every used device, allocating temporaries where needed.
    // Without split only ctx.device is used; devices with an empty row range are skipped.
    for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
        if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
            continue;
        }

        used_devices++;

        const bool src1_on_device = id == src1_ctx->device;
        const bool  dst_on_device = id == dst_ctx->device;

        ggml_cuda_set_device(id);
        cudaStream_t stream = ctx.stream(id, 0);

        if (src0_is_contiguous) {
            dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
        } else {
            // If src0 is not contiguous it will be copied to a temporary buffer.
            // This buffer needs to be cleared entirely because multiple regions will function as padding.
            const size_t nbytes_data    = ggml_nbytes(src0);
            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
        }

        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
        if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
            GGML_ASSERT(ggml_is_contiguously_allocated(src0));
            GGML_ASSERT(!src0->view_src);
            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data, 0, nbytes_padding, stream));
        }

        // src1 as float: use the tensor directly if it is contiguous and on this device,
        // otherwise allocate a buffer that the main loop below fills
        if (src1_on_device && src1_is_contiguous) {
            dev[id].src1_ddf = (float *) src1->data;
        } else {
            dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
        }

        if (quantize_src1) {
            // the mmq quantization needs get_mmq_x_max_host(cc) extra block_q8_1_mmq of space
            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
            }
            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);

            // contiguous src1 on this device is quantized all at once here;
            // otherwise quantization/transfer happens per column chunk in the main loop
            if (src1_on_device && src1_is_contiguous) {
                quantize_src1(
                    dev[id].src1_ddf, nullptr, dev[id].src1_ddq, src0->type, ne10,
                    nb11/sizeof(float), nb12/sizeof(float), nb13/sizeof(float),
                    src1_padded_col_size, ne11, ne12, ne13, stream);
                CUDA_CHECK(cudaGetLastError());
            }
        }

        // dst: write in place if dst lives on this device, otherwise into a temporary buffer
        // (for split src0 it only holds this device's rows) that is copied back after each op call
        if (dst_on_device) {
            dev[id].dst_dd = (float *) dst->data;
        } else {
            const size_t size_dst_ddf = split ? (dev[id].row_high - dev[id].row_low)*ne1 : ggml_nelements(dst);
            dev[id].dst_dd = dev[id].dst_dd_alloc.alloc(ctx.pool(id), size_dst_ddf);
        }
    }

    // if multiple devices are used they need to wait for the main device
    // here an event is recorded that signals that the main device has finished calculating the input data
    if (split && used_devices > 1) {
        ggml_cuda_set_device(ctx.device);
        CUDA_CHECK(cudaEventRecord(src0_extra->events[ctx.device][0], ctx.stream()));
    }

    // With more than one device the src1 columns are processed in chunks of MUL_MAT_SRC1_COL_STRIDE, and each
    // chunk is assigned stream index `is` (round-robin over GGML_CUDA_MAX_STREAMS); otherwise all columns at once.
    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_CUDA_MAX_STREAMS : 0;
        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;

        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
            if ((!split && id != ctx.device) || dev[id].row_low == dev[id].row_high) {
                continue;
            }

            const bool src1_on_device = id == src1_ctx->device;
            const bool  dst_on_device = id == dst_ctx->device;
            const int64_t row_diff = dev[id].row_high - dev[id].row_low;

            ggml_cuda_set_device(id);
            cudaStream_t stream = ctx.stream(id, is);

            // wait for main GPU data if necessary
            if (split && (id != ctx.device || is != 0)) {
                CUDA_CHECK(cudaStreamWaitEvent(stream, src0_extra->events[ctx.device][0], 0));
            }

            // iterate over all ne12*ne13 src1 matrices
            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
                const int64_t i03 = i0 / ne12;
                const int64_t i02 = i0 % ne12;

                // byte offset of this matrix/column chunk in the quantized src1 buffer; the mmq layout
                // advances by one block_q8_1_mmq per column, the other layout by one padded row per column
                size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
                if (quantize_src1 == quantize_mmq_q8_1_cuda) {
                    src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
                } else {
                    src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
                }

                // src0 matrix used for this src1 matrix (broadcast via i02_divisor / i03_divisor)
                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
                float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
                char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
                float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);

                // the main device memory buffer can be on VRAM scratch, with space for all partial results
                // in that case an offset on dst_ddf_i is needed
                if (id == ctx.device) {
                    dst_dd_i += dev[id].row_low; // offset is 0 if no tensor split
                }

                // copy src0, src1 to device if necessary
                if (src1_is_contiguous) {
                    if (id != ctx.device) {
                        if (quantize_src1) {
                            // NOTE(review): this copies from dev[ctx.device].src1_ddq, which is only set up in
                            // pass 2 when ctx.device was not skipped there (non-empty row range) — confirm that
                            // this is guaranteed for every tensor_split configuration.
                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
                            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
                                const size_t pitch = ne11*sizeof(block_q8_1_mmq);
                                const size_t width = src1_ncols*sizeof(block_q8_1_mmq);
                                const size_t height = src1_padded_col_size/(4*QK8_1);
                                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
                            } else {
                                CUDA_CHECK(cudaMemcpyPeerAsync(
                                    src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
                            }
                        } else {
                            float * src1_ddf_i_source = (float *) src1->data;
                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
                            CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddf_i, id, src1_ddf_i_source, ctx.device,
                                                            src1_ncols*ne10*sizeof(float), stream));
                        }
                    }
                } else if (src1_on_device && !src1_is_contiguous) {
                    // non-contiguous src1: gather this matrix's column chunk into the float buffer
                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
                                src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
                } else {
                    // non-contiguous src1 on a different device is not supported
                    GGML_ABORT("fatal error");
                }

                // non-contiguous src1 is quantized per chunk, after the gather above
                if (quantize_src1 && !src1_is_contiguous) {
                    quantize_src1(
                        src1_ddf_i, nullptr, src1_ddq_i, src0->type, ne10, ne10, ne11*ne10, ne12*ne11*ne10,
                        src1_padded_col_size, src1_ncols, 1, 1, stream);
                    CUDA_CHECK(cudaGetLastError());
                }

                // non-contiguous src0: copy this device's rows of the needed src0 matrix once
                // (first column chunk, first src1 matrix that maps to it)
                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                }

                // do the computation
                op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
                    dev[id].row_low, dev[id].row_high, src1_ncols, src1_padded_col_size, stream);
                CUDA_CHECK(cudaGetLastError());

                // copy dst to host or other device if necessary
                if (!dst_on_device) {
                    void * dst_off_device = dst->data;
                    if (split) {
                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
                        // dst is NOT transposed.
                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                        dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
                        CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
                            dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
                    } else {
                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                        dhf_dst_i += src1_col_0*ne0;
                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_dd_i, src1_ncols*ne0*sizeof(float), cudaMemcpyDeviceToDevice, stream));
                    }
                }

                // add event for the main device to wait on until other device is done
                if (split && (id != ctx.device || is != 0)) {
                    CUDA_CHECK(cudaEventRecord(src0_extra->events[id][is], stream));
                }
            }
        }
    }

    // main device waits for all other devices to be finished
    // (is_max = number of column chunks, capped at GGML_CUDA_MAX_STREAMS)
    if (split && ggml_backend_cuda_get_device_count() > 1) {
        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
        is_max = is_max <= GGML_CUDA_MAX_STREAMS ? is_max : GGML_CUDA_MAX_STREAMS;

        ggml_cuda_set_device(ctx.device);
        for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
            if (dev[id].row_low == dev[id].row_high) {
                continue;
            }
            for (int64_t is = 0; is < is_max; ++is) {
                CUDA_CHECK(cudaStreamWaitEvent(ctx.stream(), src0_extra->events[id][is], 0));
            }
        }
    }
}
1771
// Fill the per-matrix pointer arrays consumed by cublasGemmBatchedEx.
//
// Launch layout: a 2D grid where the x dimension covers i13 in [0, ne13) and the y dimension covers
// i12 in [0, ne12); threads outside that range do nothing. For the matrix with flat index
// ib = i13*ne12 + i12 this writes:
//   ptrs_src[ib]        - src0 matrix (broadcast: index i12/r2, i13/r3)
//   ptrs_src[ne23 + ib] - src1 matrix
//   ptrs_dst[ib]        - dst matrix
// All strides (nb02, nb03, nb12, nb13, nbd2, nbd3) are in bytes.
static __global__ void k_compute_batched_ptrs(
        const void * src0_as_f16, const void * src1_as_f16, char * dst,
        const void ** ptrs_src, void ** ptrs_dst,
        int64_t ne12, int64_t ne13,
        int64_t ne23,
        size_t  nb02, size_t  nb03,
        size_t  nb12, size_t  nb13,
        size_t  nbd2, size_t  nbd3,
        int64_t r2,   int64_t r3) {
    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;

    if (i12 < ne12 && i13 < ne13) {
        // dim 2 varies fastest in the flat batch index
        const int64_t ib = i13*ne12 + i12;

        const char * mat_a = (const char *) src0_as_f16 + (i12/r2)*nb02 + (i13/r3)*nb03;
        const char * mat_b = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
        char       * mat_c = dst + i12*nbd2 + i13*nbd3;

        ptrs_src[ib]        = mat_a;
        ptrs_src[ne23 + ib] = mat_b;
        ptrs_dst[ib]        = mat_c;
    }
}
1795
// Type traits for mapping ggml types to CUDA/cuBLAS types
//
// Used by ggml_cuda_mul_mat_batched_cublas_impl<src0_type>. Each specialization provides:
//   cuda_type          - element type in which src0/src1 are handed to cuBLAS
//   compute_type       - default cuBLAS compute type
//   data_type          - cudaDataType_t matching cuda_type
//   ggml_type_val      - the ggml type the specialization describes
//   alpha / beta       - GEMM scaling factors, stored in the type matching compute_type
//                        (float for CUBLAS_COMPUTE_32F, half for CUBLAS_COMPUTE_16F)
//   get_alpha/get_beta - pointers to function-local static copies of alpha/beta, passed to cuBLAS
//   get_nc_converter   - returns the conversion routine from `src_type` to cuda_type; the impl calls it
//                        with explicit element strides (s11, s12, s13), i.e. for possibly non-contiguous input
template<ggml_type T>
struct batched_mul_mat_traits;

// F32: FP32 data and FP32 compute.
template<>
struct batched_mul_mat_traits<GGML_TYPE_F32> {
    using cuda_type = float;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_32F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F32;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp32_nc_cuda(src_type); }
};

// BF16: BF16 data with FP32 compute, hence float alpha/beta.
template<>
struct batched_mul_mat_traits<GGML_TYPE_BF16> {
    using cuda_type = nv_bfloat16;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
    static inline const cudaDataType_t data_type = CUDA_R_16BF;
    static inline const ggml_type ggml_type_val = GGML_TYPE_BF16;
    static inline const float alpha = 1.0f;
    static inline const float beta = 0.0f;
    static inline const void* get_alpha() { static const float val = alpha; return &val; }
    static inline const void* get_beta() { static const float val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_bf16_nc_cuda(src_type); }
};

// F16: FP16 data with FP16 compute by default, hence half alpha/beta
// (the impl may switch to FP32 compute with float scalars).
template<>
struct batched_mul_mat_traits<GGML_TYPE_F16> {
    using cuda_type = half;
    static inline const cublasComputeType_t compute_type = CUBLAS_COMPUTE_16F;
    static inline const cudaDataType_t data_type = CUDA_R_16F;
    static inline const ggml_type ggml_type_val = GGML_TYPE_F16;
    static inline const half alpha = 1.0;
    static inline const half beta = 0.0;
    static inline const void* get_alpha() { static const half val = alpha; return &val; }
    static inline const void* get_beta() { static const half val = beta; return &val; }
    static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
};
1838
// Batched matrix multiplication for all ne12*ne13 src1 matrices via cuBLAS, for src0 of type src0_type
// (F32, BF16 or F16, see batched_mul_mat_traits). src0 is broadcast over dims 2/3 by the factors
// r2 = ne12/ne02 and r3 = ne13/ne03.
//
// - src1 is converted to src0_type when its type differs.
// - With GGML_PREC_DEFAULT and a non-F32 src0 the GEMM writes cuda_t values into a temporary buffer that
//   is converted to F32 into dst at the end; otherwise results are written as F32 directly into dst.
// - Without broadcasting and with both inputs contiguous across dims 2/3 a single
//   cublasGemmStridedBatchedEx call is used; otherwise per-matrix pointer arrays are built on the device
//   (k_compute_batched_ptrs) and cublasGemmBatchedEx is used.
// All work is enqueued on ctx.stream(). src0 must not be split, src0/src1 not transposed, dst contiguous.
template<ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    using traits = batched_mul_mat_traits<src0_type>;
    using cuda_t = typename traits::cuda_type;

    GGML_ASSERT(!ggml_is_transposed(src0));
    GGML_ASSERT(!ggml_is_transposed(src1));
    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft));
    GGML_ASSERT(src0->type == src0_type);
    GGML_ASSERT(ggml_is_contiguous(dst));

    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
    // As long as dst is contiguous this does not matter though.

    GGML_TENSOR_BINARY_OP_LOCALS

    const int64_t ne_dst = ggml_nelements(dst);
    cudaStream_t main_stream = ctx.stream();
    CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));

    float * dst_ddf = (float *) dst->data;
    // src1 strides in elements (dim 0 must be dense); recomputed below if src1 gets converted
    const size_t ts_src1 = ggml_type_size(src1->type);
    GGML_ASSERT(nb10 == ts_src1);
    int64_t s11 = nb11 / ts_src1;
    int64_t s12 = nb12 / ts_src1;
    int64_t s13 = nb13 / ts_src1;

    const cuda_t * src0_ptr = nullptr;
    const cuda_t * src1_ptr = nullptr;

    // NOTE(review): src0_alloc is never allocated or read below; it looks like a leftover.
    ggml_cuda_pool_alloc<cuda_t> src0_alloc(ctx.pool());
    ggml_cuda_pool_alloc<cuda_t> src1_alloc(ctx.pool());

    // used below to decide whether the strided-batched GEMM path can be taken
    bool is_src0_cont_2 = ggml_is_contiguous_2(src0);
    bool is_src1_cont_2 = ggml_is_contiguous_2(src1);

    // Handle src0
    src0_ptr = (const cuda_t *) src0->data;

    // Handle src1 - convert if necessary
    if (src1->type == src0_type) {
        src1_ptr = (const cuda_t *) src1->data;
    } else {
        // Convert src1 to target type using traits conversion functions
        const int64_t ne_src1 = ggml_nelements(src1);
        src1_alloc.alloc(ne_src1);

        const auto convert_func = traits::get_nc_converter(src1->type);
        GGML_ASSERT(convert_func != nullptr);
        convert_func(src1->data, src1_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
        src1_ptr = src1_alloc.get();
        // the converted copy is densely packed
        s11 = ne10;
        s12 = ne11*s11;
        s13 = ne12*s12;

        is_src1_cont_2 = true;
    }

    // Setup destination buffer
    ggml_cuda_pool_alloc<cuda_t> dst_temp(ctx.pool());
    char * dst_t;
    size_t nbd2 = dst->nb[2];
    size_t nbd3 = dst->nb[3];

    // GEMM configuration: A and B always use traits::data_type; the C data type, compute type and
    // alpha/beta may be switched to FP32 below. alpha_f32/beta_f32 are used whenever compute is FP32.
    cublasComputeType_t cu_compute_type = traits::compute_type;
    cudaDataType_t cu_data_type = traits::data_type;
    cudaDataType_t cu_data_type_a = traits::data_type;
    cudaDataType_t cu_data_type_b = traits::data_type;
    const void * alpha = traits::get_alpha();
    const void * beta = traits::get_beta();
    const float alpha_f32 = 1.0f;
    const float beta_f32 = 0.0f;

    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
        if constexpr (src0_type == GGML_TYPE_F32) {
            dst_t = (char *) dst_ddf;  // Direct F32 output
        } else {
            // cuda_t output into a temporary buffer; its byte strides shrink by sizeof(float)/sizeof(cuda_t)
            dst_t = (char *) dst_temp.alloc(ne_dst);
            nbd2 /= sizeof(float) / sizeof(cuda_t);
            nbd3 /= sizeof(float) / sizeof(cuda_t);
        }
    } else {
        // non-default precision: F32 output written directly into dst, FP32 compute
        dst_t = (char *) dst_ddf;
        cu_compute_type = CUBLAS_COMPUTE_32F;
        cu_data_type = CUDA_R_32F;
        alpha = &alpha_f32;
        beta = &beta_f32;
    }

    // CDNA and RDNA4 devices always use FP32 compute (output data type unchanged)
    int id = ggml_cuda_get_device();
    const int cc = ggml_cuda_info().devices[id].cc;
    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
        cu_compute_type = CUBLAS_COMPUTE_32F;
        alpha = &alpha_f32;
        beta = &beta_f32;
    }

    GGML_ASSERT(ne12 % ne02 == 0);
    GGML_ASSERT(ne13 % ne03 == 0);

    // broadcast factors
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

    if (r2 == 1 && r3 == 1 && is_src0_cont_2 && is_src1_cont_2) {
        // with a [0, 2, 1, 3] perm. and ne02==1 the matrix strides need to be determined from dim 3:
        const int64_t sma = ne02 == 1 ? nb03/nb00 : nb02/nb00;
        const int64_t smb = ne12 == 1 ? s13       : s12;

        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
        // use cublasGemmStridedBatchedEx
        CUBLAS_CHECK(
        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
                beta,     dst_t, cu_data_type,   ne0,       ne1*ne0, // strideC
                ne12*ne13,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    } else {
        // use cublasGemmBatchedEx
        // build per-matrix A/B/C pointer arrays on the device first (src0 pointers broadcast via r2/r3)
        const int64_t ne23 = ne12*ne13;

        ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
        ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);

        // a converted src1 has strides in cuda_t elements
        size_t src1_stride_size = sizeof(cuda_t);

        const int threads_x = 16;
        const int threads_y = 16;
        dim3 block_dims(threads_x, threads_y);

        dim3 grid_dims(
            (ne13 + threads_x - 1) / threads_x,
            (ne12 + threads_y - 1) / threads_y
        );
        k_compute_batched_ptrs<<<grid_dims, block_dims, 0, main_stream>>>(
                src0_ptr, src1_ptr, dst_t,
                ptrs_src.get(), ptrs_dst.get(),
                ne12, ne13,
                ne23,
                nb02, nb03,
                (src1->type == src0_type) ? nb12 : s12*src1_stride_size,
                (src1->type == src0_type) ? nb13 : s13*src1_stride_size,
                nbd2, nbd3,
                r2, r3);

        CUDA_CHECK(cudaGetLastError());

        CUBLAS_CHECK(
        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                ne01, ne11, ne10,
                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
                ne23,
                cu_compute_type,
                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    }

    // Convert output back to F32 if needed
    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type != CUDA_R_32F) {
        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(traits::ggml_type_val);
        to_fp32_cuda(dst_temp.get(), dst_ddf, ne_dst, main_stream);
    }
}
2006
// Batched cuBLAS mul_mat entry point: dispatches to the template instantiation matching src0's type.
// Only F16, BF16 and F32 src0 are accepted; anything else trips the assertion (or the abort below).
static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F32);

    switch (src0->type) {
        case GGML_TYPE_F16: {
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F16>(ctx, src0, src1, dst);
            return;
        }
        case GGML_TYPE_BF16: {
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_BF16>(ctx, src0, src1, dst);
            return;
        }
        case GGML_TYPE_F32: {
            ggml_cuda_mul_mat_batched_cublas_impl<GGML_TYPE_F32>(ctx, src0, src1, dst);
            return;
        }
        default:
            break;
    }

    GGML_ABORT("Unsupported type");
}
2024
// Decide whether the pattern  glu(ffn_gate [+ ffn_gate_bias], ffn_up [+ ffn_up_bias])  may be executed as a
// single fused mul_mat / mul_mat_id operation.
//
// ffn_up, ffn_gate - both GGML_OP_MUL_MAT or both GGML_OP_MUL_MAT_ID nodes; must be non-null.
// glu              - the GGML_OP_GLU node consuming them; must be non-null.
// ffn_up_bias,
// ffn_gate_bias    - optional bias nodes (GGML_OP_ADD for mul_mat, GGML_OP_ADD_ID for mul_mat_id);
//                    either both or neither must be given.
//
// Returns true only when the graph wiring, weight types/shapes/strides, shared inputs, GLU variant and
// buffer types all allow fusion. Split (multi-GPU) weight buffers are not supported yet.
static bool ggml_cuda_should_fuse_mul_mat(const ggml_tensor * ffn_up,
                                          const ggml_tensor * ffn_gate,
                                          const ggml_tensor * glu,
                                          const ggml_tensor * ffn_up_bias = nullptr,
                                          const ggml_tensor * ffn_gate_bias = nullptr) {
    // Check the mandatory nodes before any of them is dereferenced below.
    GGML_ASSERT(ffn_up && ffn_gate && glu);

    const bool has_bias = ffn_up_bias != nullptr || ffn_gate_bias != nullptr;

    // biases must come in pairs
    if (has_bias && (!ffn_up_bias || !ffn_gate_bias)) {
        return false;
    }

    const bool is_mul_mat     = ffn_up->op == GGML_OP_MUL_MAT     && ffn_gate->op == GGML_OP_MUL_MAT     && glu->op == GGML_OP_GLU;
    const bool is_mul_mat_id  = ffn_up->op == GGML_OP_MUL_MAT_ID  && ffn_gate->op == GGML_OP_MUL_MAT_ID  && glu->op == GGML_OP_GLU;

    if (!is_mul_mat && !is_mul_mat_id) {
        return false;
    }

    const ggml_op expected_bias_op = is_mul_mat ? GGML_OP_ADD : GGML_OP_ADD_ID;

    if (has_bias) {
        if (ffn_up_bias->op != expected_bias_op || ffn_gate_bias->op != expected_bias_op) {
            return false;
        }

        // the GLU must consume exactly (gate_bias, up_bias)
        if (glu->src[0] != ffn_gate_bias || glu->src[1] != ffn_up_bias) {
            return false;
        }

        if (expected_bias_op == GGML_OP_ADD) {
            const bool up_has_mul   = ffn_up_bias->src[0] == ffn_up || ffn_up_bias->src[1] == ffn_up;
            const bool gate_has_mul = ffn_gate_bias->src[0] == ffn_gate || ffn_gate_bias->src[1] == ffn_gate;
            if (!up_has_mul || !gate_has_mul) {
                return false;
            }
        } else { // GGML_OP_ADD_ID
            if (ffn_up_bias->src[0] != ffn_up || ffn_gate_bias->src[0] != ffn_gate) {
                return false;
            }
            // the bias must be selected with the same expert ids as the mul_mat_id
            if (ffn_up_bias->src[2] != ffn_up->src[2] || ffn_gate_bias->src[2] != ffn_gate->src[2]) {
                return false;
            }
        }
    } else {
        // Without biases the GLU must consume exactly (ffn_gate, ffn_up), mirroring the bias branch above.
        // Requiring only one of the two operands to match would accept a GLU wired to an unrelated tensor.
        if (glu->src[0] != ffn_gate || glu->src[1] != ffn_up) {
            return false;
        }
    }

    // both weight matrices must be interchangeable in type and memory layout
    if (ffn_up->src[0]->type != ffn_gate->src[0]->type || !ggml_are_same_shape(ffn_up->src[0], ffn_gate->src[0]) ||
        !ggml_are_same_stride(ffn_up->src[0], ffn_gate->src[0])) {
        return false;
    }

    // both projections must read the same activations
    if (ffn_up->src[1] != ffn_gate->src[1]) {
        return false;
    }

    // and, for mul_mat_id, the same expert ids
    if (ffn_up->src[2] && (ffn_up->src[2] != ffn_gate->src[2])) {
        return false;
    }

    static constexpr std::array<ggml_glu_op, 3> valid_glu_ops = { GGML_GLU_OP_SWIGLU, GGML_GLU_OP_GEGLU, GGML_GLU_OP_SWIGLU_OAI };

    if (std::find(valid_glu_ops.begin(), valid_glu_ops.end(), ggml_get_glu_op(glu)) == valid_glu_ops.end()) {
        return false;
    }

    // op param 1 of the GLU is read as a "swapped" flag; swapped operands are not supported
    if (const bool swapped = ggml_get_op_params_i32(glu, 1); swapped) {
        return false;
    }

    const bool split = ggml_backend_buft_is_cuda_split(ffn_up->src[0]->buffer->buft) ||
                       ggml_backend_buft_is_cuda_split(ffn_gate->src[0]->buffer->buft);

    //TODO: add support for fusion for split buffers
    if (split) {
        return false;
    }

    return true;
}
2109
// Decide whether a float mat-vec node (GGML_OP_MUL_MAT or GGML_OP_MUL_MAT_ID) can use the fused
// mul_mat_vec_f path.
// Requires:
//   - F32/F16/BF16 src0 and F32 src1/dst,
//   - ggml_cuda_should_use_mmvf approval for the current device,
//   - no split buffers,
//   - a single destination column (dst->ne[1] for MUL_MAT, dst->ne[2] for MUL_MAT_ID).
static bool ggml_cuda_should_fuse_mul_mat_vec_f(const ggml_tensor * tensor) {
    const ggml_tensor * dst  = tensor;
    ggml_tensor *       src0 = tensor->src[0];
    ggml_tensor *       src1 = tensor->src[1];

    const bool is_mul_mat    = tensor->op == GGML_OP_MUL_MAT;
    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;

    const bool src0_float = src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16;
    const bool types_ok   = src0_float && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;

    const int  cc       = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const bool eligible = types_ok &&
        ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, is_mul_mat_id ? src1->ne[2] : src1->ne[1]);

    //TODO: add support for fusion for split buffers
    if (ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
        ggml_backend_buft_is_cuda_split(src1->buffer->buft)) {
        return false;
    }

    // fusion is only implemented for ncols_dst == 1
    if ((is_mul_mat && dst->ne[1] != 1) || (is_mul_mat_id && dst->ne[2] != 1)) {
        return false;
    }

    return eligible;
}
2144
// Decide whether a quantized mat-vec node (GGML_OP_MUL_MAT or GGML_OP_MUL_MAT_ID) can use the fused
// mul_mat_vec_q path.
// Requires:
//   - quantized src0 whose padding can be cleared safely,
//   - F32 src1/dst,
//   - at most MMVQ_MAX_BATCH_SIZE src1 columns,
//   - a device newer than Pascal,
//   - a single destination column (dst->ne[1] for MUL_MAT, dst->ne[2] for MUL_MAT_ID),
//   - no split buffers.
static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) {
    const ggml_tensor * dst  = tensor;
    ggml_tensor *       src0 = tensor->src[0];
    ggml_tensor *       src1 = tensor->src[1];

    // A view into a compute buffer whose allocation is larger than its data: clearing the padding
    // could overwrite data belonging to other tensors.
    const bool padding_unsafe =
        ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE &&
        ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) &&
        src0->view_src;

    const bool types_ok = ggml_is_quantized(src0->type) && !padding_unsafe &&
                          src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
    const bool eligible = types_ok && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;

    // fusion is not universally faster on Pascal
    const int  cc                = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const bool newer_than_pascal = cc > GGML_CUDA_CC_PASCAL;
    if (!newer_than_pascal) {
        return false;
    }

    // fusion is only implemented for ncols_dst == 1
    const bool is_mul_mat    = tensor->op == GGML_OP_MUL_MAT;
    const bool is_mul_mat_id = tensor->op == GGML_OP_MUL_MAT_ID;
    if ((is_mul_mat && dst->ne[1] != 1) || (is_mul_mat_id && dst->ne[2] != 1)) {
        return false;
    }

    //TODO: add support for fusion for split buffers
    if (ggml_backend_buft_is_cuda_split(src0->buffer->buft) ||
        ggml_backend_buft_is_cuda_split(src1->buffer->buft)) {
        return false;
    }

    return eligible;
}
2182
// Dense MUL_MAT: dst = src0 x src1, where src0 may live in a split (multi-GPU) buffer.
//
// Kernel choice, in priority order:
//   non-split src0: MMVF -> MMF -> MMVQ -> MMQ -> batched cuBLAS (non-transposed, src1->ne[2]*ne[3] > 1)
//                   -> ggml_cuda_op_mul_mat with cuBLAS
//   split src0:     ggml_cuda_op_mul_mat with MMVF -> MMVQ -> MMQ -> cuBLAS
// For split src0, a kernel is only kept if every device that receives a non-empty row range accepts it.
static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft);

    // A temporary compute buffer for src0 may carry padding that MMVQ/MMQ would need to clear. If src0 is
    // additionally a view, clearing it could overwrite valid data of the viewed tensor, so use cuBLAS instead.
    const bool bad_padding_clear = ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE
        && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;

    const bool f32_io    = src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
    const bool quantized = ggml_is_quantized(src0->type);

    bool use_mul_mat_vec_f = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16) && f32_io;
    bool use_mul_mat_f     = !quantized && f32_io;
    bool use_mul_mat_vec_q = quantized && !bad_padding_clear && f32_io && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
    bool use_mul_mat_q     = quantized && !bad_padding_clear && f32_io;

    bool any_gpus_with_slow_fp16 = false;

    // Narrow the candidate kernels to those the given device accepts for this problem shape.
    auto restrict_to_device = [&](const int device) {
        const int dev_cc    = ggml_cuda_info().devices[device].cc;
        const int warp_size = ggml_cuda_info().devices[device].warp_size;
        use_mul_mat_q           = use_mul_mat_q     && ggml_cuda_should_use_mmq(src0->type, dev_cc, src1->ne[1], /*n_experts=*/0);
        use_mul_mat_f           = use_mul_mat_f     && ggml_cuda_should_use_mmf(src0->type, dev_cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false);
        use_mul_mat_vec_f       = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, dev_cc, src0->ne, src0->nb, src1->ne[1]);
        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(dev_cc);
    };

    if (split) {
        const ggml_backend_cuda_split_buffer_type_context * buft_ctx =
            (const ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
        const auto & tensor_split = buft_ctx->tensor_split;
        const int    n_devices    = ggml_backend_cuda_get_device_count();
        for (int id = 0; id < n_devices; ++id) {
            const float row_end = id + 1 < n_devices ? tensor_split[id + 1] : 1.0f;
            if (tensor_split[id] >= row_end) {
                continue; // this device receives no rows, so it does not constrain the choice
            }
            restrict_to_device(id);
        }
    } else {
        restrict_to_device(ctx.device);
    }

    // debug helpers
    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

    //TODO update for generic tensor parallelism
    const int  cc                      = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const bool use_batched_cublas_f16  = src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16);
    const bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc);
    const bool use_batched_cublas_f32  = src0->type == GGML_TYPE_F32;

    if (!split) {
        if (use_mul_mat_vec_f) {
            // the custom F16 vector kernel can be used over batched cuBLAS GEMM
            // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, nullptr, dst);
        } else if (use_mul_mat_f) {
            ggml_cuda_mul_mat_f(ctx, src0, src1, nullptr, dst);
        } else if (use_mul_mat_vec_q) {
            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
        } else if (use_mul_mat_q) {
            ggml_cuda_mul_mat_q(ctx, src0, src1, nullptr, dst);
        } else if ((use_batched_cublas_f16 || use_batched_cublas_bf16 || use_batched_cublas_f32)
                   && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
            // general KQ + KQV multi-batch without FlashAttention
            ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
        } else {
            ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
        }
        return;
    }

    // split src0: always go through the multi-device ggml_cuda_op_mul_mat driver
    if (use_mul_mat_vec_f) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_f, nullptr);
    } else if (use_mul_mat_vec_q) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
    } else if (use_mul_mat_q) {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
    } else {
        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
    }
}
2267
// MUL_MAT_ID: multiply src1 with the src0 expert matrices selected per token by ids (dst->src[2]).
//
// Fast paths (taken when src1/dst are F32, which is also asserted):
//   - quantized src0 and ne2 <= 4                        -> MMVQ with ids
//   - non-quantized src0 on AMD and ne2 <= batch limit   -> MMVF with ids
//   - ggml_cuda_should_use_mmq(...)                      -> MMQ with ids
//   - ggml_cuda_should_use_mmf(...)                      -> MMF with ids
// Otherwise: copy ids to the host (synchronizing the stream), bucket the (token, expert slot) rows of
// src1 by expert, gather them contiguously, run one ggml_cuda_mul_mat per used expert, and scatter
// the results back into dst.
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * ids  = dst->src[2];

    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
    GGML_ASSERT(!ggml_backend_buft_is_cuda_split(src0->buffer->buft) && "mul_mat_id does not support split buffers");

    GGML_TENSOR_BINARY_OP_LOCALS

    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;

    static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);

    if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
        const bool small_batch    = ne2 <= MMVQ_MAX_BATCH_SIZE;
        const bool src0_quantized = ggml_is_quantized(src0->type);

        if (small_batch && src0_quantized && ne2 <= 4) {
            ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
            return;
        }
        if (small_batch && !src0_quantized && GGML_CUDA_CC_IS_AMD(cc)) {
            ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
            return;
        }
        if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
            ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
            return;
        }
        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src0->nb, src1->ne[2], /*mul_mat_id=*/true)) {
            ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
            return;
        }
    }

    // ---- generic fallback: sort rows by expert, one dense mul_mat per expert ----

    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(nb12 % nb11 == 0);
    GGML_ASSERT(nb2  % nb1  == 0);

    // Sorted src1 rows are kept as F32 for quantized src0, or for F16 src0 on GPUs without fast FP16.
    const bool      sort_as_f32      = (src0->type == GGML_TYPE_F16 && !fast_fp16_hardware_available(cc))
                                       || ggml_is_quantized(src0->type);
    const ggml_type type_src1_sorted = sort_as_f32 ? GGML_TYPE_F32 : src0->type;
    const ggml_type type_dst_sorted  = GGML_TYPE_F32;
    const size_t    ts_src1_sorted   = ggml_type_size(type_src1_sorted);
    const size_t    ts_dst_sorted    = ggml_type_size(type_dst_sorted);

    const int64_t n_expert_used = ids->ne[0];
    const int64_t ne_get_rows   = ne12 * n_expert_used; // one row per (token, expert slot)

    // First half: sorted position -> src1 row index. The inverse map is appended later so that both
    // halves can be uploaded with a single copy.
    std::vector<int32_t> ids_to_sorted_host;
    ids_to_sorted_host.reserve(2*ne_get_rows);
    // Inverse map: (token, expert slot) -> sorted position.
    std::vector<int32_t> sorted_pos_of(ne_get_rows);

    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool(), 2*ne_get_rows);

    std::vector<int32_t> rows_per_expert(ne02);

    ggml_cuda_pool_alloc<char> src1_sorted(ctx.pool(), ne12*n_expert_used*ne10*ts_src1_sorted);
    ggml_cuda_pool_alloc<char>  dst_sorted(ctx.pool(), ne2 *n_expert_used* ne0*ts_dst_sorted);

    std::vector<char> ids_host(ggml_nbytes(ids));
    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    // Iterating experts in the outer loop makes each expert's rows contiguous in the sorted order.
    for (int64_t expert = 0; expert < ne02; ++expert) {
        for (int64_t token = 0; token < ne12; ++token) {
            const char * token_ids = ids_host.data() + token*ids->nb[1];
            for (int64_t slot = 0; slot < n_expert_used; ++slot) {
                const int32_t expert_to_use = *(const int32_t *)(token_ids + slot*ids->nb[0]);
                assert(expert_to_use >= 0 && expert_to_use < ne02);
                if (expert_to_use != expert) {
                    continue;
                }
                sorted_pos_of[token*n_expert_used + slot] = ids_to_sorted_host.size();
                ids_to_sorted_host.push_back(token*ne11 + slot % ne11);
                rows_per_expert[expert]++;
                break; // only the first slot of this token that selects `expert` is recorded
            }
        }
    }
    GGML_ASSERT(ids_to_sorted_host.size() == size_t(ne_get_rows));

    ids_to_sorted_host.insert(ids_to_sorted_host.end(), sorted_pos_of.begin(), sorted_pos_of.end());

    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_to_sorted_host.data(), 2*ne_get_rows*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));

    const int32_t * ids_to_sorted   = ids_buf_dev.ptr;
    const int32_t * ids_from_sorted = ids_buf_dev.ptr + ne_get_rows;

    // gather src1 rows into the expert-sorted scratch buffer (as type_src1_sorted)
    get_rows_cuda(src1->data, src1->type, ids_to_sorted, src1_sorted.ptr, type_src1_sorted,
        ne10, nb11, nb12, nb13,
        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
        ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, ne_get_rows*ne10*ts_src1_sorted, stream);
    CUDA_CHECK(cudaGetLastError());

    // Tensor header for a contiguous (width x height) matrix located at `data` in a scratch allocation.
    const auto contiguous_matrix = [](ggml_backend_buffer_t buffer, ggml_type type, int64_t width, int64_t height,
                                      size_t type_size, char * data) {
        ggml_tensor t;
        memset(&t, 0, sizeof(t));
        t.buffer = buffer;
        t.type   = type;
        t.ne[0]  = width;
        t.ne[1]  = height;
        t.ne[2]  = 1;
        t.ne[3]  = 1;
        t.nb[0]  = type_size;
        t.nb[1]  = t.ne[0] * t.nb[0];
        t.nb[2]  = t.ne[1] * t.nb[1];
        t.nb[3]  = t.ne[2] * t.nb[2];
        t.data   = data;
        return t;
    };

    char * src1_data_cur = (char *) src1_sorted.ptr;
    char *  dst_data_cur = (char *)  dst_sorted.ptr;
    for (int64_t expert = 0; expert < ne02; ++expert) {
        const int64_t n_rows = rows_per_expert[expert];
        if (n_rows == 0) {
            continue;
        }

        // single-expert view of src0
        ggml_tensor src0_slice = *src0;
        src0_slice.ne[2]    = 1;
        src0_slice.nb[3]    = src0_slice.nb[2];
        src0_slice.op       = GGML_OP_VIEW;
        src0_slice.view_src = dst->src[0]; // non-const pointer to src0
        src0_slice.data     = (char *) src0->data + expert*nb02;

        ggml_tensor src1_slice = contiguous_matrix(src1->buffer, type_src1_sorted, ne10, n_rows, ts_src1_sorted, src1_data_cur);
        ggml_tensor  dst_slice = contiguous_matrix(dst->buffer,  type_dst_sorted,  ne0,  n_rows, ts_dst_sorted,  dst_data_cur);

        ggml_cuda_mul_mat(ctx, &src0_slice, &src1_slice, &dst_slice);
        CUDA_CHECK(cudaGetLastError());

        src1_data_cur += src1_slice.nb[2];
        dst_data_cur  +=  dst_slice.nb[2];
    }

    // scatter the expert-sorted results back into dst's (token, expert slot) layout
    get_rows_cuda(dst_sorted.ptr, type_dst_sorted, ids_from_sorted, dst->data, dst->type,
        ne0, ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted, ne_get_rows*ne0*ts_dst_sorted,
        ne_get_rows, 1, 1, sizeof(int32_t), ne_get_rows*sizeof(int32_t), ne_get_rows*sizeof(int32_t),
        nb1, nb2, nb3, stream);
}
2421
// Execute a single graph node on the CUDA backend.
//
// Dispatches on dst->op (and, for GGML_OP_UNARY / GGML_OP_GLU, on the sub-op) to the matching
// ggml_cuda_* implementation. Returns false when the op or sub-op is not handled here. Otherwise it
// checks cudaGetLastError() (logging the op description and aborting via CUDA_CHECK on failure) and
// returns true.
static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) {
    // When src0 lives in a split buffer, set up peer access (passing src1->ne[1]) for this device.
    // why is this here instead of mul_mat?
    if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) {
        ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device);
    }

    // GGML_OP_UNARY sub-dispatch; false means the unary op is not handled.
    auto run_unary = [&]() -> bool {
        switch (ggml_get_unary_op(dst)) {
            case GGML_UNARY_OP_ABS:         ggml_cuda_op_abs(ctx, dst);         return true;
            case GGML_UNARY_OP_SGN:         ggml_cuda_op_sgn(ctx, dst);         return true;
            case GGML_UNARY_OP_NEG:         ggml_cuda_op_neg(ctx, dst);         return true;
            case GGML_UNARY_OP_STEP:        ggml_cuda_op_step(ctx, dst);        return true;
            case GGML_UNARY_OP_GELU:        ggml_cuda_op_gelu(ctx, dst);        return true;
            case GGML_UNARY_OP_SILU:        ggml_cuda_op_silu(ctx, dst);        return true;
            case GGML_UNARY_OP_GELU_ERF:    ggml_cuda_op_gelu_erf(ctx, dst);    return true;
            case GGML_UNARY_OP_GELU_QUICK:  ggml_cuda_op_gelu_quick(ctx, dst);  return true;
            case GGML_UNARY_OP_TANH:        ggml_cuda_op_tanh(ctx, dst);        return true;
            case GGML_UNARY_OP_RELU:        ggml_cuda_op_relu(ctx, dst);        return true;
            case GGML_UNARY_OP_SIGMOID:     ggml_cuda_op_sigmoid(ctx, dst);     return true;
            case GGML_UNARY_OP_HARDSIGMOID: ggml_cuda_op_hardsigmoid(ctx, dst); return true;
            case GGML_UNARY_OP_HARDSWISH:   ggml_cuda_op_hardswish(ctx, dst);   return true;
            case GGML_UNARY_OP_EXP:         ggml_cuda_op_exp(ctx, dst);         return true;
            case GGML_UNARY_OP_ELU:         ggml_cuda_op_elu(ctx, dst);         return true;
            case GGML_UNARY_OP_XIELU:       ggml_cuda_op_xielu(ctx, dst);       return true;
            case GGML_UNARY_OP_FLOOR:       ggml_cuda_op_floor(ctx, dst);       return true;
            case GGML_UNARY_OP_CEIL:        ggml_cuda_op_ceil(ctx, dst);        return true;
            case GGML_UNARY_OP_ROUND:       ggml_cuda_op_round(ctx, dst);       return true;
            case GGML_UNARY_OP_TRUNC:       ggml_cuda_op_trunc(ctx, dst);       return true;
            case GGML_UNARY_OP_EXPM1:       ggml_cuda_op_expm1(ctx, dst);       return true;
            case GGML_UNARY_OP_SOFTPLUS:    ggml_cuda_op_softplus(ctx, dst);    return true;
            default:                        return false;
        }
    };

    // GGML_OP_GLU sub-dispatch; false means the GLU variant is not handled.
    auto run_glu = [&]() -> bool {
        switch (ggml_get_glu_op(dst)) {
            case GGML_GLU_OP_REGLU:       ggml_cuda_op_reglu(ctx, dst);       return true;
            case GGML_GLU_OP_GEGLU:       ggml_cuda_op_geglu(ctx, dst);       return true;
            case GGML_GLU_OP_SWIGLU:      ggml_cuda_op_swiglu(ctx, dst);      return true;
            case GGML_GLU_OP_SWIGLU_OAI:  ggml_cuda_op_swiglu_oai(ctx, dst);  return true;
            case GGML_GLU_OP_GEGLU_ERF:   ggml_cuda_op_geglu_erf(ctx, dst);   return true;
            case GGML_GLU_OP_GEGLU_QUICK: ggml_cuda_op_geglu_quick(ctx, dst); return true;
            default:                      return false;
        }
    };

    switch (dst->op) {
        // ops for which nothing is launched here
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            break;

        case GGML_OP_UNARY:
            if (!run_unary()) {
                return false;
            }
            break;
        case GGML_OP_GLU:
            if (!run_glu()) {
                return false;
            }
            break;

        case GGML_OP_ARGMAX:                  ggml_cuda_argmax(ctx, dst);                    break;
        case GGML_OP_COUNT_EQUAL:             ggml_cuda_count_equal(ctx, dst);               break;
        case GGML_OP_REPEAT:                  ggml_cuda_op_repeat(ctx, dst);                 break;
        case GGML_OP_REPEAT_BACK:             ggml_cuda_op_repeat_back(ctx, dst);            break;
        case GGML_OP_GET_ROWS:                ggml_cuda_op_get_rows(ctx, dst);               break;
        case GGML_OP_GET_ROWS_BACK:           ggml_cuda_op_get_rows_back(ctx, dst);          break;
        case GGML_OP_SET_ROWS:                ggml_cuda_op_set_rows(ctx, dst);               break;
        case GGML_OP_SET:                     ggml_cuda_op_set(ctx, dst);                    break;
        case GGML_OP_DUP:                     ggml_cuda_dup(ctx, dst);                       break;
        case GGML_OP_CPY:                     ggml_cuda_cpy(ctx, dst->src[0], dst->src[1]);  break;
        case GGML_OP_CONT:                    ggml_cuda_dup(ctx, dst);                       break;
        case GGML_OP_ADD:
        case GGML_OP_ADD1: /* TODO: more efficient implementation */
                                              ggml_cuda_op_add(ctx, dst);                    break;
        case GGML_OP_ADD_ID:                  ggml_cuda_op_add_id(ctx, dst);                 break;
        case GGML_OP_SUB:                     ggml_cuda_op_sub(ctx, dst);                    break;
        case GGML_OP_ACC:                     ggml_cuda_op_acc(ctx, dst);                    break;
        case GGML_OP_MUL:                     ggml_cuda_op_mul(ctx, dst);                    break;
        case GGML_OP_DIV:                     ggml_cuda_op_div(ctx, dst);                    break;
        case GGML_OP_NORM:                    ggml_cuda_op_norm(ctx, dst);                   break;
        case GGML_OP_GROUP_NORM:              ggml_cuda_op_group_norm(ctx, dst);             break;
        case GGML_OP_L2_NORM:                 ggml_cuda_op_l2_norm(ctx, dst);                break;
        case GGML_OP_CONCAT:                  ggml_cuda_op_concat(ctx, dst);                 break;
        case GGML_OP_UPSCALE:                 ggml_cuda_op_upscale(ctx, dst);                break;
        case GGML_OP_PAD:                     ggml_cuda_op_pad(ctx, dst);                    break;
        case GGML_OP_PAD_REFLECT_1D:          ggml_cuda_op_pad_reflect_1d(ctx, dst);         break;
        case GGML_OP_ARANGE:                  ggml_cuda_op_arange(ctx, dst);                 break;
        case GGML_OP_TIMESTEP_EMBEDDING:      ggml_cuda_op_timestep_embedding(ctx, dst);     break;
        case GGML_OP_LEAKY_RELU:              ggml_cuda_op_leaky_relu(ctx, dst);             break;
        case GGML_OP_SILU_BACK:               ggml_cuda_op_silu_back(ctx, dst);              break;
        case GGML_OP_RMS_NORM:                ggml_cuda_op_rms_norm(ctx, dst);               break;
        case GGML_OP_RMS_NORM_BACK:           ggml_cuda_op_rms_norm_back(ctx, dst);          break;
        case GGML_OP_MUL_MAT:                 ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst); break;
        case GGML_OP_MUL_MAT_ID:              ggml_cuda_mul_mat_id(ctx, dst);                break;
        case GGML_OP_OUT_PROD:                ggml_cuda_out_prod(ctx, dst);                  break;
        case GGML_OP_SCALE:                   ggml_cuda_op_scale(ctx, dst);                  break;
        case GGML_OP_SQR:                     ggml_cuda_op_sqr(ctx, dst);                    break;
        case GGML_OP_SQRT:                    ggml_cuda_op_sqrt(ctx, dst);                   break;
        case GGML_OP_SIN:                     ggml_cuda_op_sin(ctx, dst);                    break;
        case GGML_OP_COS:                     ggml_cuda_op_cos(ctx, dst);                    break;
        case GGML_OP_CLAMP:                   ggml_cuda_op_clamp(ctx, dst);                  break;
        case GGML_OP_LOG:                     ggml_cuda_op_log(ctx, dst);                    break;
        case GGML_OP_DIAG:                    ggml_cuda_op_diag(ctx, dst);                   break;
        case GGML_OP_DIAG_MASK_INF:           ggml_cuda_op_diag_mask_inf(ctx, dst);          break;
        case GGML_OP_SOFT_MAX:                ggml_cuda_op_soft_max(ctx, dst);               break;
        case GGML_OP_SOFT_MAX_BACK:           ggml_cuda_op_soft_max_back(ctx, dst);          break;
        case GGML_OP_ROPE:                    ggml_cuda_op_rope(ctx, dst);                   break;
        case GGML_OP_ROPE_BACK:               ggml_cuda_op_rope_back(ctx, dst);              break;
        case GGML_OP_ROLL:                    ggml_cuda_op_roll(ctx, dst);                   break;
        case GGML_OP_IM2COL:                  ggml_cuda_op_im2col(ctx, dst);                 break;
        case GGML_OP_IM2COL_3D:               ggml_cuda_op_im2col_3d(ctx, dst);              break;
        case GGML_OP_CONV_2D:                 ggml_cuda_op_conv2d(ctx, dst);                 break;
        case GGML_OP_CONV_2D_DW:              ggml_cuda_op_conv2d_dw(ctx, dst);              break;
        case GGML_OP_CONV_TRANSPOSE_2D:       ggml_cuda_conv_2d_transpose_p0(ctx, dst);      break;
        case GGML_OP_CONV_TRANSPOSE_1D:       ggml_cuda_op_conv_transpose_1d(ctx, dst);      break;
        case GGML_OP_POOL_2D:                 ggml_cuda_op_pool2d(ctx, dst);                 break;
        case GGML_OP_SUM:                     ggml_cuda_op_sum(ctx, dst);                    break;
        case GGML_OP_CUMSUM:                  ggml_cuda_op_cumsum(ctx, dst);                 break;
        case GGML_OP_SUM_ROWS:                ggml_cuda_op_sum_rows(ctx, dst);               break;
        case GGML_OP_MEAN:                    ggml_cuda_op_mean(ctx, dst);                   break;
        case GGML_OP_SSM_CONV:                ggml_cuda_op_ssm_conv(ctx, dst);               break;
        case GGML_OP_SSM_SCAN:                ggml_cuda_op_ssm_scan(ctx, dst);               break;
        case GGML_OP_TOP_K:                   ggml_cuda_op_top_k(ctx, dst);                  break;
        case GGML_OP_ARGSORT:                 ggml_cuda_op_argsort(ctx, dst);                break;
        case GGML_OP_FLASH_ATTN_EXT:          ggml_cuda_flash_attn_ext(ctx, dst);            break;
        case GGML_OP_CROSS_ENTROPY_LOSS:      ggml_cuda_cross_entropy_loss(ctx, dst);        break;
        case GGML_OP_TRI:                     ggml_cuda_op_tri(ctx, dst);                    break;
        case GGML_OP_RWKV_WKV6:               ggml_cuda_op_rwkv_wkv6(ctx, dst);              break;
        case GGML_OP_GATED_LINEAR_ATTN:       ggml_cuda_op_gated_linear_attn(ctx, dst);      break;
        case GGML_OP_RWKV_WKV7:               ggml_cuda_op_rwkv_wkv7(ctx, dst);              break;
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_cuda_cross_entropy_loss_back(ctx, dst);   break;
        case GGML_OP_OPT_STEP_ADAMW:          ggml_cuda_opt_step_adamw(ctx, dst);            break;
        case GGML_OP_OPT_STEP_SGD:            ggml_cuda_opt_step_sgd(ctx, dst);              break;
        case GGML_OP_SOLVE_TRI:               ggml_cuda_op_solve_tri(ctx, dst);              break;
        case GGML_OP_FILL:                    ggml_cuda_op_fill(ctx, dst);                   break;

        default:
            return false;
    }

    // surface any launch error from the op just dispatched, naming the op in the log
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        GGML_LOG_ERROR("%s: %s failed\n", __func__, ggml_op_desc(dst));
        CUDA_CHECK(err);
    }

    return true;
}
2763
2764////////////////////////////////////////////////////////////////////////////////
2765
2766// backend
2767
// Return the name string held by this backend's CUDA context. The pointer stays valid for as long as the context exists.
static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) {
    const auto * cuda_ctx = static_cast<const ggml_backend_cuda_context *>(backend->context);
    return cuda_ctx->name.c_str();
}
2773
// Destroy a CUDA backend: delete its CUDA context first, then the backend object itself.
static void ggml_backend_cuda_free(ggml_backend_t backend) {
    delete static_cast<ggml_backend_cuda_context *>(backend->context);
    delete backend;
}
2780
// Enqueue a host -> device copy of `size` bytes from `data` into `tensor` at byte `offset`, on this
// backend's stream. The buffer owning the tensor's storage (the view source's buffer for views) must
// be this device's CUDA buffer type.
static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(backend->context);

    ggml_backend_buffer_t buf = tensor->view_src != nullptr ? tensor->view_src->buffer : tensor->buffer;
    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");

    CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream()));
}
2789
// Enqueue a device -> host copy of `size` bytes from `tensor` at byte `offset` into `data`, on this
// backend's stream. The host data is only valid after the stream has been synchronized. The buffer
// owning the tensor's storage (the view source's buffer for views) must be this device's CUDA buffer type.
static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(backend->context);

    ggml_backend_buffer_t buf = tensor->view_src != nullptr ? tensor->view_src->buffer : tensor->buffer;
    GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type");

    CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream()));
}
2798
// Asynchronously copy ggml_nbytes(dst) bytes from `src` to `dst` when both backends and both
// (view-resolved) buffers are CUDA.
//
// Returns false, so the caller must use another copy path, when:
//   - either backend or buffer is not CUDA, or a buffer is missing;
//   - a backend's device does not match its buffer's device;
//   - a cross-device copy is needed but GGML_CUDA_NO_PEER_COPY is defined.
// Otherwise the copy is enqueued on the source backend's stream. For distinct backends, an event is
// recorded after the copy and the destination stream waits on it.
static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
    // For views, use the buffer of the viewed tensor (same convention as set/get_tensor_async).
    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;

    if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) {
        return false;
    }

    // Validate the same buffers whose contexts are cast to ggml_backend_cuda_buffer_context below
    // (previously src->buffer/dst->buffer were checked instead). A missing buffer cannot be copied
    // here, so let the caller fall back.
    if (buf_src == nullptr || buf_dst == nullptr ||
        !ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) {
        return false;
    }

    // device -> device copy
    ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context;
    ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context;

    ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context;
    ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context;

    if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
#ifndef NDEBUG
        GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
#endif
        return false;
    }

    if (backend_src != backend_dst) {
        // copy on src stream
        if (cuda_ctx_src->device == cuda_ctx_dst->device) {
            CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
        } else {
#ifdef GGML_CUDA_NO_PEER_COPY
            return false;
#else
            CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream()));
#endif
        }

        // record event on src stream after the copy (event is created lazily on the src device)
        if (!cuda_ctx_src->copy_event) {
            ggml_cuda_set_device(cuda_ctx_src->device);
            CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming));
        }

        CUDA_CHECK(cudaEventRecord(cuda_ctx_src->copy_event, cuda_ctx_src->stream()));

        // wait on dst stream for the copy to complete
        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0));
    } else {
        // src and dst are on the same backend
        CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream()));
    }
    return true;
}
2853
2854static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
2855    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
2856
2857    CUDA_CHECK(cudaStreamSynchronize(cuda_ctx->stream()));
2858
2859    GGML_UNUSED(backend);
2860}
2861
2862#ifdef USE_CUDA_GRAPH
2863static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
2864
2865    bool use_cuda_graph = true;
2866    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
2867
2868    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
2869    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
2870    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
2871    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
2872    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
2873    const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out";
2874    const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d";
2875
2876    for (int i = 0; i < cgraph->n_nodes; i++) {
2877        ggml_tensor * node = cgraph->nodes[i];
2878
2879        if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2880            continue;
2881        }
2882
2883        if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
2884            use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
2885#ifndef NDEBUG
2886            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
2887#endif
2888        }
2889
2890        if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2891            use_cuda_graph = false; // This node type is not supported by CUDA graph capture
2892#ifndef NDEBUG
2893            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
2894#endif
2895        }
2896
2897        if (node->op == GGML_OP_ADD &&
2898            node->src[1] && node->src[1]->ne[1] > 1 &&
2899            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
2900            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
2901            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
2902            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
2903            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 &&
2904            strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 &&
2905            strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) {
2906            // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
2907            // by means of matching node names. See
2908            // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
2909            // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
2910            // Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
2911            use_cuda_graph = false;
2912#ifndef NDEBUG
2913            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
2914#endif
2915        }
2916
2917        if (!use_cuda_graph) {
2918            break;
2919        }
2920    }
2921
2922    return use_cuda_graph;
2923}
2924
2925static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
2926    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
2927    props->node_data = node->data;
2928    props->node_op = node->op;
2929    props->node_type = node->type;
2930    props->flags = node->flags;
2931    for (int i = 0; i < GGML_MAX_DIMS; i++) {
2932        props->ne[i] = node->ne[i];
2933        props->nb[i] = node->nb[i];
2934    }
2935    for (int i = 0; i < GGML_MAX_SRC; i++) {
2936        if (!node->src[i]) {
2937            continue;
2938        }
2939
2940        props->src_data[i] = node->src[i]->data;
2941    }
2942    memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
2943}
2944
2945static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
2946    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
2947        return false;
2948    }
2949
2950    if (node->op != props->node_op) {
2951        return false;
2952    }
2953
2954    if (node->type != props->node_type) {
2955        return false;
2956    }
2957
2958    for (int i = 0; i < GGML_MAX_DIMS; i++) {
2959        if (node->ne[i] != props->ne[i]) {
2960            return false;
2961        }
2962        if (node->nb[i] != props->nb[i]) {
2963            return false;
2964        }
2965    }
2966
2967    if (node->op != GGML_OP_VIEW) {
2968        for (int i = 0; i < GGML_MAX_SRC; i++) {
2969            if (!node->src[i]) {
2970                if (props->src_data[i] != nullptr) {
2971                    return false;
2972                }
2973                continue;
2974            }
2975
2976            if (node->src[i]->data != props->src_data[i]) {
2977                return false;
2978            }
2979        }
2980    }
2981
2982    if (memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
2983        return false;
2984    }
2985
2986    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
2987        return false;
2988    }
2989
2990    return true;
2991}
2992
2993static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
2994    return cgraph->nodes[0];
2995}
2996
2997static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
2998    bool res = false;
2999
3000    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
3001    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3002
3003    if (graph->instance == nullptr) {
3004        res = true;
3005    }
3006
3007    // Check if the graph size has changed
3008    if (graph->props.size() != (size_t)cgraph->n_nodes) {
3009        res = true;
3010        graph->props.resize(cgraph->n_nodes);
3011    }
3012
3013    // Loop over nodes in GGML graph to determine if CUDA graph update is required
3014    // and store properties to allow this comparison for the next token
3015    std::unordered_set<ggml_tensor *> seen_node;
3016    std::vector<ggml_tensor *> srcs_extra;
3017    for (int i = 0; i < cgraph->n_nodes; i++) {
3018        bool props_match = true;
3019
3020        seen_node.insert(cgraph->nodes[i]);
3021
3022        if (!res) {
3023            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
3024        }
3025        if (!props_match) {
3026            res = true;
3027        }
3028        ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
3029
3030        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
3031            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
3032            if (src && seen_node.find(src) == seen_node.end()) {
3033                srcs_extra.push_back(src);
3034            }
3035        }
3036    }
3037
3038    if (graph->extra.size() != (size_t) srcs_extra.size()) {
3039        res = true;
3040        graph->extra.resize(srcs_extra.size());
3041    }
3042
3043    for (size_t i = 0; i < srcs_extra.size(); ++i) {
3044        bool props_match = true;
3045
3046        if (!res) {
3047            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
3048        }
3049
3050        if (!props_match) {
3051            res = true;
3052        }
3053        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
3054    }
3055
3056    return res;
3057}
3058
3059static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3060    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3061
3062#if CUDART_VERSION >= 12000
3063    cudaGraphExecUpdateResultInfo result_info;
3064    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
3065#else
3066    cudaGraphNode_t errorNode;
3067    cudaGraphExecUpdateResult result_info;
3068    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
3069#endif // CUDART_VERSION >= 12000
3070
3071    if (stat == cudaErrorGraphExecUpdateFailure) {
3072#ifndef NDEBUG
3073        GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
3074#endif
3075
3076        // The pre-existing graph exec cannot be updated due to violated constraints
3077        // so instead clear error and re-instantiate
3078        (void)cudaGetLastError();
3079        CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
3080        graph->instance = nullptr;
3081        CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3082    } else {
3083        GGML_ASSERT(stat == cudaSuccess);
3084    }
3085}
3086#endif // USE_CUDA_GRAPH
3087
3088static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
3089                                                const ggml_tensor * view,
3090                                                const ggml_tensor * set_rows) {
3091
3092    if (rope->op != GGML_OP_ROPE || view->op != GGML_OP_VIEW || set_rows->op != GGML_OP_SET_ROWS) {
3093        return false;
3094    }
3095    // ne3 not tested
3096    if (rope->src[0]->ne[3] != 1) {
3097        return false;
3098    }
3099
3100    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
3101        return false;
3102    }
3103
3104    if (set_rows->src[1]->type != GGML_TYPE_I64) {
3105        return false;
3106    }
3107
3108    // The view should flatten two dims of rope into one dim
3109    if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
3110        return false;
3111    }
3112
3113    // Only norm/neox shaders have the fusion code
3114    const int mode = ((const int32_t *) rope->op_params)[2];
3115    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
3116        return false;
3117    }
3118
3119    return true;
3120}
3121
3122static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
3123    args.sigmoid         = false;
3124    args.softmax         = false;
3125    args.delayed_softmax = false;
3126    args.prob_bias       = false;
3127    args.norm            = false;
3128
3129    const int      n_nodes = cgraph->n_nodes;
3130    ggml_tensor ** nodes   = cgraph->nodes;
3131
3132    if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
3133        args.softmax = true;
3134    }
3135
3136    if (nodes[node_idx]->op == GGML_OP_UNARY) {
3137        if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
3138            return false;
3139        }
3140        args.sigmoid = true;
3141    }
3142
3143    if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
3144        args.delayed_softmax = true;
3145    }
3146
3147    node_idx++;
3148
3149    if (args.sigmoid || args.softmax) {
3150        // SOFTMAX -> RESHAPE
3151        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
3152                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3153            return false;
3154        }
3155        ggml_tensor * probs_reshaped = nodes[node_idx];
3156        node_idx++;
3157
3158        if (node_idx >= n_nodes) {
3159            return false;
3160        }
3161
3162        // src of bias add is the unreshaped probs (-2 instead of -1)
3163        if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
3164            args.prob_bias = true;
3165            node_idx++;
3166        }
3167        // RESHAPE/ADD -> ARGSORT
3168        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
3169            return false;
3170        }
3171
3172        if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3173            return false;
3174        } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
3175            return false;
3176        }
3177
3178        node_idx++;
3179
3180        // ARGSORT-> VIEW
3181        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3182                nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3183            return false;
3184        }
3185        node_idx++;
3186
3187        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
3188            return false;
3189        }
3190
3191        // GET_ROWS
3192        if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
3193            return false;
3194        }
3195        node_idx++;
3196    } else if (args.delayed_softmax) {
3197        if (node_idx - 2 < 0) {
3198            return false;
3199        }
3200        ggml_tensor * probs_reshaped = nodes[node_idx - 2];
3201
3202        // VIEW->ARGSORT
3203        if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
3204            nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3205            return false;
3206        }
3207        node_idx++;
3208
3209        // GET_ROWS
3210        if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3211                nodes[node_idx]->src[0] != probs_reshaped) {
3212            return false;
3213        }
3214        node_idx++;
3215
3216        static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
3217
3218        for (const ggml_op op : remaining_ops) {
3219            if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3220                return false;
3221            }
3222            node_idx++;
3223        }
3224    }
3225
3226    // At this point we can check for norm + scale. Everything is now at least valid till the norm
3227    if (node_idx >= n_nodes) {
3228        return true;
3229    }
3230
3231    if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
3232        //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
3233        static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
3234
3235        args.norm = true;
3236        for (const ggml_op op : norm_ops) {
3237            if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3238                node_idx++;
3239            } else {
3240                args.norm = false;
3241                return true;
3242            }
3243        }
3244
3245        // DIV <- CLAMP, RESHAPE
3246        if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
3247            nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
3248            args.norm = false;
3249            return true;
3250        }
3251        node_idx++;
3252
3253        if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
3254            args.norm = false;
3255            return true;
3256        }
3257
3258        node_idx++;
3259    }
3260
3261    if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
3262        args.scale = true;
3263    }
3264
3265    return true;
3266}
3267
// Decide whether the op sequence `ops`, starting at cgraph->nodes[node_idx], can be executed as one
// fused CUDA kernel. `unary_ops` lists the unary op for each GGML_OP_UNARY entry in `ops`, in order
// (checked by assertion in debug builds).
//
// Supported sequences:
//   * MUL_MAT, ADD, MUL_MAT, ADD, GLU (and the MUL_MAT_ID/ADD_ID variant) - gated FFN with biases
//   * MUL_MAT, MUL_MAT, GLU (and the MUL_MAT_ID variant)                  - gated FFN without biases
//   * ROPE, VIEW, SET_ROWS
//   * RMS_NORM, MUL [, ADD]                                                - F32 only
//   * SCALE, UNARY(TANH), SCALE                                            - bias-free scales only
// The first three groups are checked with ggml_can_fuse_subgraph. The remaining ones are gated by
// ggml_can_fuse.
// Returns true if the sequence is supported and passes the per-pattern checks, false otherwise.
static bool ggml_cuda_can_fuse(const struct ggml_cgraph *                cgraph,
                               int                                       node_idx,
                               std::initializer_list<enum ggml_op>       ops,
                               std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG
    const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
    GGML_ASSERT(unary_ops.size() == num_unary);
#endif

    // Element-wise equality of two op lists (lengths must match too).
    const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
                             const std::initializer_list<enum ggml_op> & list2) {
        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
    };

    std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_MUL_MAT,    GGML_OP_ADD,    GGML_OP_GLU };
    std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };

    std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
    std::initializer_list<enum ggml_op> mul_mat_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_MUL_MAT,    GGML_OP_GLU };

    // Gated FFN with biases: gate, gate_bias, up, up_bias, glu.
    // The trailing list passed to ggml_can_fuse_subgraph presumably names the output node(s) that
    // must not be elided (here: the GLU) — TODO confirm against ggml_can_fuse_subgraph.
    if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
        const ggml_tensor * ffn_gate      = cgraph->nodes[node_idx];
        const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
        const ggml_tensor * ffn_up        = cgraph->nodes[node_idx + 2];
        const ggml_tensor * ffn_up_bias   = cgraph->nodes[node_idx + 3];
        const ggml_tensor * glu           = cgraph->nodes[node_idx + 4];

        // Note the argument order: up first, then gate.
        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) {
            return true;
        }
    }

    // Gated FFN without biases: gate, up, glu.
    if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
        const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
        const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
        const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];

        if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) {
            return true;
        }
    }

    std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };

    // ROPE written directly into the rows selected by SET_ROWS.
    if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
        const ggml_tensor * rope     = cgraph->nodes[node_idx];
        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
        const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];

        if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
            return true;
        }
    }

    // Every remaining pattern first requires the generic fusability check.
    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }

    // RMS_NORM -> MUL, optionally followed by ADD.
    // NOTE(review): a 3-op list whose third op is not ADD is still handled here as the 2-op case.
    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
        const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
        const ggml_tensor *add      = nullptr;

        if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
            add = cgraph->nodes[node_idx+2];
        }

        // These are hard assertions, not fusion rejections.
        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);

        //rms norm only supports F32
        if (mul->src[0]->type != GGML_TYPE_F32 ||
            mul->src[1]->type != GGML_TYPE_F32 ||
            mul->type != GGML_TYPE_F32) {
            return false;
        }

        if (add && (add->src[0]->type != GGML_TYPE_F32 ||
            add->src[1]->type != GGML_TYPE_F32 ||
            add->type != GGML_TYPE_F32) ) {
            return false;
        }

        //if rms norm is the B operand, then we don't handle broadcast
        if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
            return false;
        }

        //rms_norm kernel assumes contigous rows
        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
            return false;
        }

        // The ADD's first operand must be fully contiguous; its second only needs contiguous rows.
        if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
            return false;
        }

        return true;
    }

    // SCALE -> TANH -> SCALE (soft-capping style), only when neither SCALE has a bias term.
    if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE
     && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) {
        const ggml_tensor *scale  = cgraph->nodes[node_idx];
        const ggml_tensor *tanh   = cgraph->nodes[node_idx+1];
        const ggml_tensor *scale2 = cgraph->nodes[node_idx+2];

        GGML_ASSERT(scale->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(scale->type == GGML_TYPE_F32);

        // The requested unary op must also be what the graph node actually computes.
        if (ggml_get_unary_op(tanh) != GGML_UNARY_OP_TANH) {
            return false;
        }

        // Check for bias
        // (op param 1 of SCALE is read as the bias; only a zero bias is fusable)
        if (ggml_get_op_params_f32(scale, 1) != 0.0f || ggml_get_op_params_f32(scale2, 1) != 0.0f) {
            return false;
        }

        return true;
    }

    return false;
}
3393
3394static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
3395    bool graph_evaluated_or_captured = false;
3396
3397    // flag used to determine whether it is an integrated_gpu
3398    const bool integrated            = ggml_cuda_info().devices[cuda_ctx->device].integrated;
3399
3400    ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
3401    bool                         is_concurrent_event_active = false;
3402    ggml_cuda_concurrent_event * concurrent_event           = nullptr;
3403    bool                         should_launch_concurrent_events = false;
3404
3405    const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3406        if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
3407            concurrent_event = &stream_ctx.concurrent_events[node];
3408
3409            is_concurrent_event_active = true;
3410
3411            GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
3412
3413            cudaStream_t main_stream = cuda_ctx->stream();  // this should be stream 0
3414            GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
3415            CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
3416
3417            for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3418                cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
3419                CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
3420            }
3421        }
3422    };
3423
3424    while (!graph_evaluated_or_captured) {
3425        // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
3426        // With the use of CUDA graphs, the execution will be performed by the graph launch.
3427        if (!use_cuda_graph || cuda_graph_update_required) {
3428            [[maybe_unused]] int prev_i = 0;
3429
3430            if (stream_ctx.concurrent_events.size() > 0) {
3431                should_launch_concurrent_events = true;
3432                for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
3433                    should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
3434                }
3435            }
3436
3437            if (should_launch_concurrent_events) {
3438                // Restore original node order within each concurrent region to enable fusion within streams
3439
3440                std::unordered_map<const ggml_tensor *, int> node_to_idx;
3441                node_to_idx.reserve(cgraph->n_nodes);
3442                for (int i = 0; i < cgraph->n_nodes; ++i) {
3443                    node_to_idx[cgraph->nodes[i]] = i;
3444                }
3445
3446                for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
3447                    // Find positions of all nodes from this event in the current graph
3448                    std::vector<int> positions;
3449                    positions.reserve(event.original_order.size());
3450
3451                    bool all_found = true;
3452                    for (const ggml_tensor * orig_node : event.original_order) {
3453                        auto it = node_to_idx.find(orig_node);
3454                        if (it != node_to_idx.end()) {
3455                            positions.push_back(it->second);
3456                        } else {
3457                            all_found = false;
3458                            break;
3459                        }
3460                    }
3461
3462                    if (!all_found || positions.size() != event.original_order.size()) {
3463                        continue;
3464                    }
3465
3466                    // Sort positions to get contiguous range
3467                    std::vector<int> sorted_positions = positions;
3468                    std::sort(sorted_positions.begin(), sorted_positions.end());
3469
3470                    bool is_contiguous = true;
3471                    for (size_t i = 1; i < sorted_positions.size(); ++i) {
3472                        if (sorted_positions[i] != sorted_positions[i-1] + 1) {
3473                            is_contiguous = false;
3474                            break;
3475                        }
3476                    }
3477
3478                    if (!is_contiguous) {
3479                        continue;
3480                    }
3481
3482                    // Restore original order at the sorted positions
3483                    int start_pos = sorted_positions[0];
3484                    for (size_t i = 0; i < event.original_order.size(); ++i) {
3485                        cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
3486                    }
3487                }
3488            } else {
3489                stream_ctx.concurrent_events.clear();
3490            }
3491
3492            for (int i = 0; i < cgraph->n_nodes; i++) {
3493                ggml_tensor * node = cgraph->nodes[i];
3494                if (is_concurrent_event_active) {
3495                    GGML_ASSERT(concurrent_event);
3496
3497                    if (node == concurrent_event->join_node) {
3498                        cuda_ctx->curr_stream_no = 0;
3499                        for (int i = 1; i <= concurrent_event->n_streams; ++i) {
3500                            // Wait on join events of forked streams in the main stream
3501                            CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
3502                                                       cuda_ctx->stream(cuda_ctx->device, i)));
3503                            CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
3504                        }
3505
3506                        is_concurrent_event_active = false;
3507                        concurrent_event           = nullptr;
3508                    } else {
3509                        GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
3510                        cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3511                        GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3512                    }
3513                } else if (i - prev_i > 1) {
3514                    //the previous node was fused
3515                    const ggml_tensor * prev_node = cgraph->nodes[i - 1];
3516                    try_launch_concurrent_event(prev_node);
3517
3518                    if (is_concurrent_event_active) {
3519                        cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
3520                        GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
3521                    }
3522                }
3523
3524#ifdef GGML_CUDA_DEBUG
3525                const int nodes_fused = i - prev_i - 1;
3526                if (nodes_fused > 0) {
3527                    GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
3528                }
3529#endif
3530                prev_i = i;
3531
3532                if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
3533                    continue;
3534                }
3535
3536                if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
3537                    continue;
3538                }
3539
3540                // start of fusion operations
3541                static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
3542                if (!disable_fusion) {
3543                    ggml_cuda_topk_moe_args args;
3544
3545                    if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
3546                        cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
3547                        const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
3548
3549                        std::vector<ggml_op> ops;
3550
3551                        if (can_fuse) {
3552                            const ggml_tensor * logits  = node->src[0];
3553                            ggml_tensor *       weights = nullptr;
3554                            ggml_tensor *       ids     = nullptr;
3555                            const ggml_tensor * bias    = nullptr;
3556                            const ggml_tensor * clamp   = nullptr;
3557                            const ggml_tensor * scale   = nullptr;
3558
3559                            if (!args.delayed_softmax) {
3560                                ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
3561                                int     out_nodes[2];  // nodes which can't be elided
3562
3563                                if (args.prob_bias) {
3564                                    bias = cgraph->nodes[i + 2]->src[1];
3565                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
3566                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS });
3567                                    out_nodes[0] = i + 4;
3568                                    ids          = cgraph->nodes[i + 4];
3569                                } else {
3570                                    ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
3571                                                            GGML_OP_GET_ROWS });
3572                                    out_nodes[0] = i + 3;
3573                                    ids          = cgraph->nodes[i + 3];
3574                                }
3575
3576                                if (args.norm) {
3577                                    ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
3578                                                            GGML_OP_DIV, GGML_OP_RESHAPE });
3579                                    clamp = cgraph->nodes[i + ops.size() - 3];
3580                                }
3581                                if (args.scale) {
3582                                    ops.insert(ops.end(), { GGML_OP_SCALE });
3583                                    scale = cgraph->nodes[i + ops.size() - 1];
3584                                }
3585
3586                                weights      = cgraph->nodes[i + ops.size() - 1];
3587                                out_nodes[1] = i + ops.size() - 1;
3588
3589                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3590                                        ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
3591                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3592                                    i += ops.size() - 1;
3593                                    continue;
3594                                }
3595                            } else if (!args.norm && !args.prob_bias) {
3596                                //special case gpt-oss, no norm, no bias.
3597                                ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
3598                                                        GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
3599                                weights                     = cgraph->nodes[i + 5];
3600                                ids                         = cgraph->nodes[i + 1];
3601                                const ggml_tensor * softmax = cgraph->nodes[i + 4];
3602
3603                                int out_nodes[2] = { i + 1, i + 5 };
3604                                if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
3605                                        ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
3606                                    ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
3607                                    i += ops.size() - 1;
3608                                    continue;
3609                                }
3610                            }
3611                        }
3612                    }
3613
3614                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
3615                        ggml_tensor * rope = cgraph->nodes[i];
3616                        ggml_tensor * set_rows = cgraph->nodes[i + 2];
3617
3618                        ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
3619                        i += 2;
3620                        continue;
3621                    }
3622
3623                    if (node->op == GGML_OP_ADD) {
3624                        int n_fuse = 0;
3625                        ggml_op ops[8];
3626                        std::fill(ops, ops + 8, GGML_OP_ADD);
3627
3628                        for (; n_fuse <= 6; ++n_fuse){
3629                            if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
3630                                break;
3631                            }
3632                            if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
3633                                break;
3634                            }
3635                            if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
3636                                break;
3637                            }
3638                        }
3639
3640                        n_fuse++;
3641
3642                        if (n_fuse > 1) {
3643                            for (int j = 0; j < n_fuse - 1; ++j) {
3644                                node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
3645                            }
3646                            cgraph->nodes[i + n_fuse - 1]->data = node->data;
3647                            ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
3648                            i += n_fuse - 1;
3649
3650                            continue;
3651                        }
3652                    }
3653
3654                    bool fused_mul_mat_vec = false;
3655                    int fused_node_count = 0;
3656
3657                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3658                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3659
3660                        if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) {
3661                            ggml_tensor * glu         = cgraph->nodes[i + 4];
3662                            ggml_tensor * gate_bias_n = glu->src[0];
3663                            ggml_tensor * up_bias_n   = glu->src[1];
3664
3665                            //we don't assume the order for {gate, up}. Instead infer it from the bias tensor
3666                            ggml_tensor * gate_n      = nullptr;
3667                            ggml_tensor * up_n        = nullptr;
3668
3669                            if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) {
3670                                gate_n = cgraph->nodes[i];
3671                                up_n   = cgraph->nodes[i + 2];
3672                            } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) {
3673                                gate_n = cgraph->nodes[i + 2];
3674                                up_n   = cgraph->nodes[i];
3675                            } else {
3676                                continue;
3677                            }
3678
3679                            auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) {
3680                                if (op_bias == GGML_OP_ADD) {
3681                                    if (bias_node->src[0] == mul_node) {
3682                                        return bias_node->src[1];
3683                                    }
3684                                    if (bias_node->src[1] == mul_node) {
3685                                        return bias_node->src[0];
3686                                    }
3687                                    return (ggml_tensor *) nullptr;
3688                                }
3689                                GGML_ASSERT(op_bias == GGML_OP_ADD_ID);
3690                                GGML_ASSERT(bias_node->src[0] == mul_node);
3691                                return bias_node->src[1];
3692                            };
3693
3694                            ggml_tensor * up_bias_tensor   = get_bias_tensor(up_bias_n, up_n, bias_op);
3695                            ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op);
3696
3697                            if (!up_bias_tensor || !gate_bias_tensor) {
3698                                continue;
3699                            }
3700
3701                            // we don't support repeating adds
3702                            if (bias_op == GGML_OP_ADD &&
3703                                (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) ||
3704                                 !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) {
3705                                continue;
3706                            }
3707
3708                            const ggml_tensor * src0 = up_n->src[0];
3709                            const ggml_tensor * src1 = up_n->src[1];
3710                            const ggml_tensor * ids  = up_n->src[2];
3711
3712                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) {
3713                                ggml_cuda_mm_fusion_args_host fusion_data{};
3714                                fusion_data.gate      = gate_n->src[0];
3715                                fusion_data.x_bias    = up_bias_tensor;
3716                                fusion_data.gate_bias = gate_bias_tensor;
3717                                fusion_data.glu_op    = ggml_get_glu_op(glu);
3718
3719                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3720                                fused_mul_mat_vec = true;
3721                                fused_node_count = 5;
3722                                break;
3723                            }
3724
3725                            if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) {
3726                                ggml_cuda_mm_fusion_args_host fusion_data{};
3727                                fusion_data.gate      = gate_n->src[0];
3728                                fusion_data.x_bias    = up_bias_tensor;
3729                                fusion_data.gate_bias = gate_bias_tensor;
3730                                fusion_data.glu_op    = ggml_get_glu_op(glu);
3731
3732                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3733                                fused_mul_mat_vec = true;
3734                                fused_node_count = 5;
3735                                break;
3736                            }
3737                        } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) {
3738                            ggml_tensor * glu  = cgraph->nodes[i + 2];
3739                            ggml_tensor * gate = glu->src[0];
3740                            ggml_tensor * up   = glu->src[1];
3741
3742                            bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1])
3743                                || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]);
3744
3745                            if (!ok) continue;
3746
3747                            const ggml_tensor * src0 = up->src[0];
3748                            const ggml_tensor * src1 = up->src[1];
3749                            const ggml_tensor * ids  = up->src[2];
3750
3751                            if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) {
3752                                ggml_cuda_mm_fusion_args_host fusion_data{};
3753                                fusion_data.gate   = gate->src[0];
3754                                fusion_data.glu_op = ggml_get_glu_op(glu);
3755
3756                                ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3757                                fused_mul_mat_vec = true;
3758                                fused_node_count = 3;
3759                                break;
3760                            }
3761
3762                            if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) {
3763                                ggml_cuda_mm_fusion_args_host fusion_data{};
3764                                fusion_data.gate   = gate->src[0];
3765                                fusion_data.glu_op = ggml_get_glu_op(glu);
3766
3767                                ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data);
3768                                fused_mul_mat_vec = true;
3769                                fused_node_count = 3;
3770                                break;
3771                            }
3772                        }
3773                    }
3774
3775                    if (fused_mul_mat_vec) {
3776                        i += fused_node_count - 1;
3777                        continue;
3778                    }
3779
3780                    fused_mul_mat_vec = false;
3781                    fused_node_count = 0;
3782
3783                    for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) {
3784                        const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID;
3785
3786                        if (!ggml_can_fuse(cgraph, i, { op, bias_op })) {
3787                            continue;
3788                        }
3789
3790                        ggml_tensor * mm_node   = cgraph->nodes[i];
3791                        ggml_tensor * bias_node = cgraph->nodes[i + 1];
3792
3793                        ggml_tensor * bias_tensor = nullptr;
3794                        if (bias_op == GGML_OP_ADD) {
3795                            if (bias_node->src[0] == mm_node) {
3796                                bias_tensor = bias_node->src[1];
3797                            } else if (bias_node->src[1] == mm_node) {
3798                                bias_tensor = bias_node->src[0];
3799                            } else {
3800                                continue;
3801                            }
3802                        } else {
3803                            if (bias_node->src[0] != mm_node) {
3804                                continue;
3805                            }
3806                            bias_tensor = bias_node->src[1];
3807                        }
3808
3809                        const ggml_tensor * src0 = mm_node->src[0];
3810                        const ggml_tensor * src1 = mm_node->src[1];
3811                        const ggml_tensor * ids  = mm_node->src[2];
3812
3813                        if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) {
3814                            continue;
3815                        }
3816
3817                        if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) {
3818                            continue;
3819                        }
3820
3821                        ggml_cuda_mm_fusion_args_host fusion_data{};
3822                        fusion_data.x_bias = bias_tensor;
3823
3824                        if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) {
3825                            ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3826                            fused_mul_mat_vec = true;
3827                            fused_node_count = 2;
3828                            break;
3829                        }
3830
3831                        if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) {
3832                            ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data);
3833                            fused_mul_mat_vec = true;
3834                            fused_node_count = 2;
3835                            break;
3836                        }
3837                    }
3838
3839                    if (fused_mul_mat_vec) {
3840                        i += fused_node_count - 1;
3841                        continue;
3842                    }
3843
3844                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3845                        ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
3846                        i += 2;
3847                        continue;
3848                    }
3849
3850                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
3851                        ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
3852                        i++;
3853                        continue;
3854                    }
3855
3856                    if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
3857                        i += 2;
3858                        ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
3859                        continue;
3860                    }
3861                }
3862#ifndef NDEBUG
3863                assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
3864                for (int j = 0; j < GGML_MAX_SRC; j++) {
3865                    if (node->src[j] != nullptr) {
3866                        assert(node->src[j]->buffer);
3867                        assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
3868                               ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
3869                    }
3870                }
3871#else
3872                GGML_UNUSED(integrated);
3873#endif  // NDEBUG
3874
3875                bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
3876                if (!ok) {
3877                    GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
3878                }
3879                GGML_ASSERT(ok);
3880
3881                if (!is_concurrent_event_active) {
3882                    try_launch_concurrent_event(node);
3883               }
3884            }
3885        }
3886
3887#ifdef USE_CUDA_GRAPH
3888        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3889        if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
3890            if (graph->graph != nullptr) {
3891                CUDA_CHECK(cudaGraphDestroy(graph->graph));
3892                graph->graph = nullptr;
3893            }
3894
3895            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
3896            graph_evaluated_or_captured = true; // CUDA graph has been captured
3897
3898            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
3899            if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
3900                ggml_cuda_lock_cv.notify_all();
3901            }
3902        } else {
3903            graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
3904        }
3905    }
3906
3907    if (use_cuda_graph) {
3908        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3909        if (graph->instance == nullptr) { // Create executable graph from captured graph.
3910            CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
3911        }
3912        if (cuda_graph_update_required) { // Update graph executable
3913            ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
3914        }
3915        // Launch graph
3916        CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
3917#else
3918        GGML_UNUSED(graph_key);
3919        graph_evaluated_or_captured = true;
3920#endif  // USE_CUDA_GRAPH
3921    }
3922}
3923
3924#ifdef USE_CUDA_GRAPH
3925static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
3926    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3927
3928    if (graph->graph == nullptr) {
3929        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
3930            if (!graph->disable_due_to_gpu_arch) {
3931                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
3932            }
3933            graph->disable_due_to_gpu_arch = true;
3934        }
3935    }
3936
3937    return graph->is_enabled();
3938}
3939#endif // USE_CUDA_GRAPH
3940
3941static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
3942    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
3943
3944    ggml_cuda_set_device(cuda_ctx->device);
3945
3946    bool use_cuda_graph             = false;
3947    bool cuda_graph_update_required = false;
3948    const void * graph_key = nullptr;
3949
3950#ifdef USE_CUDA_GRAPH
3951    graph_key = ggml_cuda_graph_get_key(cgraph);
3952
3953    use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
3954
3955    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
3956    if (graph->is_enabled()) {
3957        cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
3958        use_cuda_graph             = ggml_cuda_graph_check_compability(cgraph);
3959
3960        graph->record_update(use_cuda_graph, cuda_graph_update_required);
3961    }
3962#endif // USE_CUDA_GRAPH
3963
3964    if (use_cuda_graph && cuda_graph_update_required) {
3965        // Start CUDA graph capture
3966        {
3967            std::lock_guard<std::mutex> lock(ggml_cuda_lock);
3968            ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
3969        }
3970
3971        CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
3972    }
3973
3974    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
3975
3976    return GGML_STATUS_SUCCESS;
3977}
3978
3979static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
3980    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3981
3982    CUDA_CHECK(cudaEventRecord((cudaEvent_t)event->context, cuda_ctx->stream()));
3983}
3984
3985static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
3986    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
3987
3988    if (ggml_backend_is_cuda(backend)) {
3989        CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), (cudaEvent_t)event->context, 0));
3990    } else {
3991#if 0
3992        // untested
3993        auto wait_fn = [](void * user_data) {
3994            ggml_backend_event_t event = (ggml_backend_event_t)user_data;
3995            ggml_backend_event_synchronize(event);
3996        };
3997
3998        CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event));
3999#endif
4000        GGML_ABORT("fatal error");
4001    }
4002}
4003
4004static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
4005    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
4006
4007#ifdef USE_CUDA_GRAPH
4008    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
4009    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
4010#else
4011    const bool use_cuda_graph = false;
4012    GGML_UNUSED(cuda_ctx);
4013    GGML_UNUSED(cgraph);
4014#endif
4015
4016    static bool enable_graph_optimization = [] {
4017        const char * env     = getenv("GGML_CUDA_GRAPH_OPT");
4018        return env != nullptr && atoi(env) == 1;
4019    }();
4020
4021    if (!enable_graph_optimization) {
4022        return;
4023    }
4024
4025    ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
4026    stream_context.reset();
4027
4028    if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
4029        return;
4030    }
4031
4032    // number of out-degrees for a particular node
4033    std::unordered_map<const ggml_tensor *, int> fan_out;
4034    // reverse mapping of node to index in the cgraph
4035    std::unordered_map<const ggml_tensor *, int> node_indices;
4036
4037    const auto & is_noop = [](const ggml_tensor * node) -> bool {
4038        return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
4039               node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
4040    };
4041
4042    const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
4043        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
4044            if (dst->src[s] == src) {
4045                return true;
4046            }
4047        }
4048        // implicit dependency if they view the same tensor
4049        const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
4050        const ggml_tensor * src2 = src->view_src ? src->view_src : src;
4051        if (dst2 == src2) {
4052            return true;
4053        }
4054        return false;
4055    };
4056
4057    for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
4058        const ggml_tensor * node = cgraph->nodes[node_idx];
4059        node_indices[node]       = node_idx;
4060
4061        if (is_noop(node)) {
4062            continue;
4063        }
4064        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
4065            const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
4066            //TODO: check why nrows > 1 fails
4067            if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
4068                fan_out[src] += 1;
4069            }
4070        }
4071    }
4072
4073    // Target Q, K, V for concurrency
4074    // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
4075    // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
4076    // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
4077    // 3. account for all branches from the fork to the join
4078    // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
4079    // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
4080    // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
4081
4082    const int min_fan_out = 3;
4083    const int max_fan_out = 3;
4084
4085    // store {fork_idx, join_idx}
4086    std::vector<std::pair<int, int>> concurrent_node_ranges;
4087
4088    for (const auto & [root_node, count] : fan_out) {
4089        if (count >= min_fan_out && count <= max_fan_out) {
4090            const int root_node_idx = node_indices[root_node];
4091
4092            // only optimize for attn_norm
4093            // TODO: make this more generic
4094            if (!strstr(root_node->name, "attn_norm")) {
4095                continue;
4096            }
4097
4098            bool is_part_of_event = false;
4099            for (const auto & [start, end] : concurrent_node_ranges) {
4100                if (root_node_idx >= start && root_node_idx <= end) {
4101                    is_part_of_event = true;
4102                }
4103            }
4104
4105            if (is_part_of_event) {
4106                continue;
4107            }
4108
4109            std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
4110            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
4111                const ggml_tensor * node = cgraph->nodes[i];
4112                if (!is_noop(node) && depends_on(node, root_node)) {
4113                    nodes_per_branch.push_back({ node });
4114                }
4115            }
4116
4117            GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
4118
4119            //find the join point
4120            const ggml_tensor * join_node = nullptr;
4121
4122            const auto & belongs_to_branch = [&](const ggml_tensor *                      node,
4123                                                 const std::vector<const ggml_tensor *> & branch) -> bool {
4124                for (const ggml_tensor * n : branch) {
4125                    if (depends_on(node, n)) {
4126                        return true;
4127                    }
4128                }
4129                return false;
4130            };
4131
4132            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
4133                const ggml_tensor * curr_node = cgraph->nodes[i];
4134
4135                int num_joins = 0;
4136                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4137                    if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
4138                        num_joins++;
4139                    }
4140                }
4141
4142                if (num_joins >= 2) {
4143                    join_node = curr_node;
4144                    break;
4145                }
4146
4147                bool found_branch = false;
4148                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4149                    std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
4150                    if (belongs_to_branch(curr_node, branch_vec)) {
4151                        //continue accumulating
4152                        if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
4153                            branch_vec.push_back(curr_node);
4154                        }
4155                        found_branch = true;
4156                    }
4157                }
4158
4159                if (!found_branch && is_noop(curr_node)) {
4160                    // we can put it in any branch because it will be ignored
4161                    nodes_per_branch[0].push_back({ curr_node });
4162                }
4163            }
4164
4165            if (join_node) {
4166                //Create ggml_cuda_concurrent_event
4167                ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
4168                concurrent_event.join_node = join_node;
4169
4170                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
4171                    for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
4172                        concurrent_event.stream_mapping[n] = branch_idx + 1;
4173                    }
4174                }
4175
4176                int fork_node_idx = node_indices[root_node];
4177                int join_node_idx = node_indices[join_node];
4178
4179                int       current_branch_idx = 0;
4180                int       current_node_idx   = fork_node_idx + 1;
4181                const int n_branches         = nodes_per_branch.size();
4182
4183                int total_branch_nodes = 0;
4184                for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
4185                    total_branch_nodes += branch_nodes.size();
4186                }
4187
4188                // there are other nodes in the middle which are unaccounted for
4189                // usually (cpy) nodes, then ignore this fork
4190                if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
4191                    GGML_LOG_DEBUG(
4192                        "Skipping %s because the number of nodes in the middle is not equal to the total number of "
4193                        "branch nodes %d != %d\n",
4194                        root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
4195                    continue;
4196                }
4197
4198                // Save the original order of nodes in this region before interleaving
4199                // This is used later to restore grouping for fusion within streams
4200                concurrent_event.original_order.reserve(total_branch_nodes);
4201                for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
4202                    concurrent_event.original_order.push_back(cgraph->nodes[i]);
4203                }
4204
4205                std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
4206                GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
4207                concurrent_events.emplace(root_node, std::move(concurrent_event));
4208                GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
4209                concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
4210
4211                // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
4212                // example transformation:
4213                // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
4214                // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
4215                while (current_node_idx < join_node_idx) {
4216                    std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
4217
4218                    bool has_node = false;
4219                    for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
4220                        has_node |= branch_node.size() > 0;
4221                    }
4222
4223                    GGML_ASSERT(has_node);
4224
4225                    if (branch_nodes.empty()) {
4226                        current_branch_idx = (current_branch_idx + 1) % n_branches;
4227                        continue;
4228                    }
4229
4230                    cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4231                    current_node_idx++;
4232                    branch_nodes.erase(branch_nodes.begin());
4233
4234                    // append all empty nodes
4235                    while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
4236                        cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
4237                        current_node_idx++;
4238                        branch_nodes.erase(branch_nodes.begin());
4239                    }
4240
4241                    current_branch_idx = (current_branch_idx + 1) % n_branches;
4242                }
4243            }
4244        }
4245    }
4246}
4247
// Backend vtable shared by every CUDA backend instance.
// The graph_plan_* hooks are left NULL: this backend does not implement graph plans and
// graphs are handed to graph_compute instead.
static const ggml_backend_i ggml_backend_cuda_interface = {
    /* .get_name                = */ ggml_backend_cuda_get_name,
    /* .free                    = */ ggml_backend_cuda_free,
    /* .set_tensor_async        = */ ggml_backend_cuda_set_tensor_async,
    /* .get_tensor_async        = */ ggml_backend_cuda_get_tensor_async,
    /* .cpy_tensor_async        = */ ggml_backend_cuda_cpy_tensor_async,
    /* .synchronize             = */ ggml_backend_cuda_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
    /* .event_record            = */ ggml_backend_cuda_event_record,
    /* .event_wait              = */ ggml_backend_cuda_event_wait,
    /* .graph_optimize          = */ ggml_backend_cuda_graph_optimize,
};
4264
// Returns the fixed GUID that tags CUDA backends; ggml_backend_is_cuda() compares against it.
static ggml_guid_t ggml_backend_cuda_guid() {
    static ggml_guid cuda_backend_guid = {
        0x2c, 0xdd, 0xe8, 0x1c, 0x65, 0xb3, 0x65, 0x73,
        0x6a, 0x12, 0x88, 0x61, 0x1c, 0xc9, 0xdc, 0x25,
    };
    return &cuda_backend_guid;
}
4269
// True when `backend` is non-null and carries the CUDA backend GUID.
bool ggml_backend_is_cuda(ggml_backend_t backend) {
    if (backend == NULL) {
        return false;
    }
    return ggml_guid_matches(backend->guid, ggml_backend_cuda_guid());
}
4273
// Number of CUDA devices recorded in the cached device info.
int ggml_backend_cuda_get_device_count() {
    const auto & info = ggml_cuda_info();
    return info.device_count;
}
4277
// Writes the device name reported by cudaGetDeviceProperties into `description`
// (at most description_size bytes including the terminator, as with snprintf).
void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size) {
    cudaDeviceProp device_prop;
    CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device));
    snprintf(description, description_size, "%s", device_prop.name);
}
4283
// Reports free/total memory of `device` as returned by cudaMemGetInfo.
void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total) {
    // cudaMemGetInfo queries the current device, so make `device` current first
    ggml_cuda_set_device(device);
    CUDA_CHECK(cudaMemGetInfo(free, total));
}
4289
// Page-locks an existing host allocation with cudaHostRegister (flags: portable + read-only).
// Opt-in: does nothing and returns false unless GGML_CUDA_REGISTER_HOST is set in the environment.
// Returns true only if the registration succeeded. On failure the CUDA error is logged at debug
// level and cleared, and false is returned. Builds without cudaHostRegisterReadOnly (CUDA < 11.1,
// and neither MUSA nor HIP) always return false.
bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
        return false;
    }

#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
    const cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
    if (err != cudaSuccess) {
        // clear the error so it is not picked up by an unrelated later CUDA call
        (void)cudaGetLastError();

        GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
                       size / 1024.0 / 1024.0, cudaGetErrorString(err));
        return false;
    }
    return true;
#else
    GGML_UNUSED(buffer);
    GGML_UNUSED(size);
    return false;
#endif // CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA) || defined(GGML_USE_HIP)
}
4312
// Undoes ggml_backend_cuda_register_host_buffer. Gated by the same GGML_CUDA_REGISTER_HOST
// environment variable; failures are ignored (best effort) after clearing the CUDA error.
void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
    if (getenv("GGML_CUDA_REGISTER_HOST") == nullptr) {
        return;
    }

    if (cudaHostUnregister(buffer) != cudaSuccess) {
        // discard the error so it does not surface in a later CUDA call
        (void) cudaGetLastError();
    }
}
4324
4325
4326// backend device
4327
// Per-device state of the CUDA backend, reached through `dev->context` by the device callbacks.
struct ggml_backend_cuda_device_context {
    // CUDA device index (passed to ggml_cuda_set_device / ggml_backend_cuda_init)
    int device;
    // returned by ggml_backend_cuda_device_get_name
    std::string name;
    // returned by ggml_backend_cuda_device_get_description
    std::string description;
    // reported as props->device_id; an empty string is reported as nullptr
    std::string pci_bus_id;
    // ops whose get_op_batch_size() is below this value are not offloaded
    int op_offload_min_batch_size;
};
4335
// Device name; the pointer stays valid as long as the device context lives.
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    return dev_ctx->name.c_str();
}
4340
// Device description; the pointer stays valid as long as the device context lives.
static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t dev) {
    const auto * dev_ctx = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    return dev_ctx->description.c_str();
}
4345
#if defined(__linux__)
// Helper that reads system memory figures from /proc/meminfo for UMA (unified memory) systems.
//
// On success returns true and writes:
//   *available_memory_kb - MemAvailable (kB), or -1 if that field is missing. When hugetlbfs
//                          pages are configured (HugePages_Total > 0) and the huge page fields
//                          were parsed, it is HugePages_Free * Hugepagesize (kB) instead.
//   *free_swap_kb        - SwapFree (kB), or -1 if missing; 0 in the hugetlbfs case.
// Returns false (outputs left untouched) if an output pointer is null, or if the file cannot be
// opened or yields no data.
//
// The file is parsed line by line with fgets rather than through a fixed-size buffer, so the
// trailing HugePages_* / Hugepagesize lines cannot be silently cut off when /proc/meminfo is
// larger than a guessed size.
static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_kb, long * free_swap_kb) {
    if (available_memory_kb == nullptr || free_swap_kb == nullptr) {
        return false;
    }

    FILE * meminfo_file = fopen("/proc/meminfo", "r");
    if (meminfo_file == nullptr) {
        GGML_LOG_ERROR("%s: failed to open /proc/meminfo\n", __func__);
        return false;
    }

    long mem_available_kb     = -1;
    long swap_free_kb         = -1;
    long huge_tlb_total_pages = -1;
    long huge_tlb_free_pages  = -1;
    long huge_tlb_page_size   = -1;
    bool got_data             = false;

    // Entries are short "Key: value [kB]" lines. A line longer than the buffer would be returned
    // in several pieces; only a piece that starts with one of the keys below can match.
    char line[256];
    while (fgets(line, sizeof(line), meminfo_file) != nullptr) {
        got_data = true;

        long value;
        if (sscanf(line, "MemAvailable: %ld kB", &value) == 1) {
            mem_available_kb = value;
        } else if (sscanf(line, "SwapFree: %ld kB", &value) == 1) {
            swap_free_kb = value;
        } else if (sscanf(line, "HugePages_Total: %ld", &value) == 1) {
            huge_tlb_total_pages = value;
        } else if (sscanf(line, "HugePages_Free: %ld", &value) == 1) {
            huge_tlb_free_pages = value;
        } else if (sscanf(line, "Hugepagesize: %ld kB", &value) == 1) {
            huge_tlb_page_size = value;
        }
    }
    fclose(meminfo_file);

    if (!got_data) {
        GGML_LOG_ERROR("%s: failed to read from /proc/meminfo\n", __func__);
        return false;
    }

    // If hugetlbfs pages are configured, report the free huge page capacity instead of
    // MemAvailable. Require the free-page count and page size to have been parsed, so that a
    // missing field cannot yield a bogus product (e.g. -1 * -1 = 1 kB, or a negative value).
    if (huge_tlb_total_pages > 0 && huge_tlb_free_pages >= 0 && huge_tlb_page_size > 0) {
        mem_available_kb = huge_tlb_free_pages * huge_tlb_page_size;

        // Hugetlbfs pages are not swappable.
        swap_free_kb = 0;
    }

    *available_memory_kb = mem_available_kb;
    *free_swap_kb        = swap_free_kb;

    GGML_LOG_DEBUG("%s: final available_memory_kb: %ld\n", __func__, *available_memory_kb);
    return true;
}
#endif // defined(__linux__)
4421
// Reports free/total memory for the device via cudaMemGetInfo. On Linux, for integrated
// devices, or when GGML_CUDA_ENABLE_UNIFIED_MEMORY is set, *free is replaced by the
// /proc/meminfo-based figure when that figure is positive.
static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    const auto * dev_ctx   = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    const int    device_id = dev_ctx->device;

    ggml_cuda_set_device(device_id);
    CUDA_CHECK(cudaMemGetInfo(free, total));

// ref: https://github.com/ggml-org/llama.cpp/pull/17368
#if defined(__linux__)
    cudaDeviceProp device_prop;
    CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id));

    // UMA handling is on for integrated devices or when forced through the environment
    const bool forced_uma = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
    if (device_prop.integrated <= 0 && !forced_uma) {
        return;
    }

    // For UMA systems (like DGX Spark), use system memory info
    long mem_available_kb = 0;
    long swap_free_kb     = 0;
    const bool meminfo_ok = ggml_backend_cuda_get_available_uma_memory(&mem_available_kb, &swap_free_kb);
    if (!meminfo_ok || mem_available_kb <= 0) {
        GGML_LOG_ERROR("%s: /proc/meminfo reading failed, using cudaMemGetInfo\n", __func__);
        return;
    }
    *free = (size_t)mem_available_kb * 1024;
#endif // defined(__linux__)
}
4451
// All CUDA devices report the GPU device type; `dev` is not consulted.
static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) {
    (void) dev;
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}
4456
// Fills the generic device properties (identity, memory and capabilities) for a CUDA device.
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    const auto *        dev_ctx = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    const std::string & bus_id  = dev_ctx->pci_bus_id;

    props->name        = ggml_backend_cuda_device_get_name(dev);
    props->description = ggml_backend_cuda_device_get_description(dev);
    props->type        = ggml_backend_cuda_device_get_type(dev);
    // an empty PCI bus id is reported as nullptr
    props->device_id   = bus_id.empty() ? nullptr : bus_id.c_str();
    ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);

    // pinned host buffers can be switched off at runtime; event support is fixed at build time
    const bool supports_host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
    constexpr bool supports_events = false;
#else
    constexpr bool supports_events = true;
#endif

    props->caps = {
        /* .async                 = */ true,
        /* .host_buffer           = */ supports_host_buffer,
        /* .buffer_from_host_ptr  = */ false,
        /* .events                = */ supports_events,
    };
}
4480
// Creates a CUDA backend bound to this device; `params` is ignored.
static ggml_backend_t ggml_backend_cuda_device_init_backend(ggml_backend_dev_t dev, const char * params) {
    (void) params;
    const auto * dev_ctx = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    return ggml_backend_cuda_init(dev_ctx->device);
}
4486
// Device-memory buffer type for this device's CUDA index.
static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_buffer_type(ggml_backend_dev_t dev) {
    const int device_id = static_cast<const ggml_backend_cuda_device_context *>(dev->context)->device;
    return ggml_backend_cuda_buffer_type(device_id);
}
4491
// The host buffer type is shared by all CUDA devices, so `dev` is not consulted.
static ggml_backend_buffer_type_t ggml_backend_cuda_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    (void) dev;
    return ggml_backend_cuda_host_buffer_type();
}
4496
// TODO: move these functions here
// Returns true if this CUDA device can execute `op`, judging by the op kind, the tensor
// types/shapes involved, and which buffers the sources are allocated in.
static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;

    // split buffers can only be used with GGML_OP_MUL_MAT
    if (op->op != GGML_OP_MUL_MAT) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda_split(op->src[i]->buffer->buft)) {
                return false;
            }
        }
    }

    // check if all the sources are allocated on this device
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (op->src[i] && op->src[i]->buffer && ggml_backend_buft_is_cuda(op->src[i]->buffer->buft)) {
            ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)op->src[i]->buffer->buft->context;
            if (buft_ctx->device != dev_ctx->device) {
                return false;
            }
        }
    }

    // per-op capability checks; any op not listed here falls through to `default` (unsupported)
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_ABS:
                case GGML_UNARY_OP_SGN:
                case GGML_UNARY_OP_NEG:
                case GGML_UNARY_OP_STEP:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_SIGMOID:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_GELU_ERF:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_EXP:
                case GGML_UNARY_OP_EXPM1:
                case GGML_UNARY_OP_SOFTPLUS:
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_XIELU:
                case GGML_UNARY_OP_FLOOR:
                case GGML_UNARY_OP_CEIL:
                case GGML_UNARY_OP_ROUND:
                case GGML_UNARY_OP_TRUNC:
                    // the unary kernels require a fully contiguous input
                    return ggml_is_contiguous(op->src[0]);
                default:
                    return false;
            }
            break;
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(op)) {
                case GGML_GLU_OP_REGLU:
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_SWIGLU:
                case GGML_GLU_OP_SWIGLU_OAI:
                case GGML_GLU_OP_GEGLU_ERF:
                case GGML_GLU_OP_GEGLU_QUICK:
                    return ggml_is_contiguous_1(op->src[0]);
                default:
                    return false;
            }
            break;
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:
            {
                struct ggml_tensor * a = op->src[0];
                struct ggml_tensor * b = op->src[1];
                // row-split weights are only supported for 2D matrices (ne[2] == ne[3] == 1)
                if (a->buffer && ggml_backend_buft_is_cuda_split(a->buffer->buft)) {
                    if (a->ne[2] > 1 || a->ne[3] > 1) {
                        return false;
                    }
                    // for small weight matrices the active device can end up without any rows, don't use row split in those cases
                    // this avoids some edge cases (and the performance would not be good anyways)
                    ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) a->buffer->buft->context;
                    int64_t row_low;
                    int64_t row_high;
                    get_row_split(&row_low, &row_high, a, buft_ctx->tensor_split, dev_ctx->device);
                    if (row_low == row_high) {
                        return false;
                    }
                }
                // an F16 src1 is only accepted together with an F16 src0
                if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                    return false;
                }
#ifdef GGML_USE_MUSA
                // MUSA: reject specific batched (ne[2]*ne[3] > 1), non-transposed type combinations
                // on QY1 (F16 x F16 MUL_MAT) and QY2 (Q2_K x F32 MUL_MAT_ID) devices
                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
                if (b->ne[2]*b->ne[3] > 1 && !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
                    if (GGML_CUDA_CC_IS_QY1(cc) && op->op == GGML_OP_MUL_MAT &&
                            a->type == GGML_TYPE_F16 && b->type == GGML_TYPE_F16) {
                        return false;
                    }
                    if (GGML_CUDA_CC_IS_QY2(cc) && op->op == GGML_OP_MUL_MAT_ID &&
                            a->type == GGML_TYPE_Q2_K && b->type == GGML_TYPE_F32) {
                        return false;
                    }
                }
#endif // GGML_USE_MUSA
                // src0 (weight) types accepted by the matrix multiplication paths
                switch (a->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_F16:
                    case GGML_TYPE_Q4_0:
                    case GGML_TYPE_Q4_1:
                    case GGML_TYPE_Q5_0:
                    case GGML_TYPE_Q5_1:
                    case GGML_TYPE_Q8_0:
                    case GGML_TYPE_MXFP4:
                    case GGML_TYPE_Q2_K:
                    case GGML_TYPE_Q3_K:
                    case GGML_TYPE_Q4_K:
                    case GGML_TYPE_Q5_K:
                    case GGML_TYPE_Q6_K:
                    case GGML_TYPE_Q8_K:
                    case GGML_TYPE_IQ1_M:
                    case GGML_TYPE_IQ1_S:
                    case GGML_TYPE_IQ2_S:
                    case GGML_TYPE_IQ2_XS:
                    case GGML_TYPE_IQ2_XXS:
                    case GGML_TYPE_IQ3_S:
                    case GGML_TYPE_IQ3_XXS:
                    case GGML_TYPE_IQ4_NL:
                    case GGML_TYPE_IQ4_XS:
                    case GGML_TYPE_BF16:
                        return true;
                    default:
                        return false;
                }
            } break;
        case GGML_OP_OUT_PROD:
            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
        case GGML_OP_GET_ROWS:
            {
                // source types the row-gather supports
                switch (op->src[0]->type) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                    case GGML_TYPE_BF16:
                    case GGML_TYPE_I32:
                    case GGML_TYPE_Q4_0:
                    case GGML_TYPE_Q4_1:
                    case GGML_TYPE_Q5_0:
                    case GGML_TYPE_Q5_1:
                    case GGML_TYPE_Q8_0:
                        return true;
                    default:
                        return false;
                }
            } break;
        case GGML_OP_GET_ROWS_BACK:
            {
                // F32 only, with a 2D result (ne[2] == ne[3] == 1)
                return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1;
            } break;
        case GGML_OP_SET_ROWS:
            {
                // F32 rows written into a float/selected-quant result, indexed by I64 or I32 row ids
                return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
                       op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
                       op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
                       op->src[0]->type == GGML_TYPE_F32 &&
                       (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
            } break;
        case GGML_OP_SET:
            {
                // F32 or I32 only, and all three tensors must share that type
                const ggml_type t = op->type;
                return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
                    t == op->src[0]->type &&
                    t == op->src[1]->type;
            } break;
        case GGML_OP_CPY:
            {
                // supported (src0, src1) type pairs: any pair of F32/BF16/F16, F32 <-> Q8_0/Q4_0/Q4_1/Q5_0/Q5_1,
                // F32 -> IQ4_NL, F32 <-> I32, I32 -> I32, and any same-type pair when both are contiguous
                ggml_type src0_type = op->src[0]->type;
                ggml_type src1_type = op->src[1]->type;
                if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) &&
                    (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16)
                ) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_0 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q4_1 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_0) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_0 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q5_1) {
                    return true;
                }
                if (src0_type == GGML_TYPE_Q5_1 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
                    return true;
                }
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
                    return true;
                }
                if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                    return true;
                }
                return false;
            } break;
        case GGML_OP_DUP:
            {
                ggml_type src0_type = op->src[0]->type;
                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
            } break;
        case GGML_OP_ARGMAX:
        case GGML_OP_COUNT_EQUAL:
            {
                return true;
            } break;
        case GGML_OP_REPEAT:
            {
                ggml_type src0_type = op->src[0]->type;
                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
            } break;
        case GGML_OP_REPEAT_BACK:
                // NOTE(review): the ne[2]*ne[3] <= 2^15 bound presumably reflects a kernel launch limit - confirm
                return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
        case GGML_OP_CONCAT:
            {
                ggml_type src0_type = op->src[0]->type;
                return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
            } break;
        case GGML_OP_CONV_TRANSPOSE_1D:
            {
                ggml_type src0_type = op->src[0]->type;
                ggml_type src1_type = op->src[1]->type;
                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                    return true;
                }
                return false;
            } break;
        case GGML_OP_SILU_BACK:
            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
            break;
        case GGML_OP_NORM:
        case GGML_OP_RMS_NORM:
        case GGML_OP_L2_NORM:
            return true;
        case GGML_OP_RMS_NORM_BACK:
            return ggml_is_contiguous(op->src[0]);
            break;
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_ADD:
        case GGML_OP_ADD_ID:
        case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_SIN:
        case GGML_OP_COS:
        case GGML_OP_CLAMP:
        case GGML_OP_LOG:
            return true;
        case GGML_OP_SSM_SCAN: {
            // src[3]->ne[0] == 1 selects the Mamba2 variant; both branches return
            if (op->src[3]->ne[0] == 1) {
                // Mamba2
                // (kernel only supports (d_state == 128 || d_state == 256) && d_head % 16 == 0)
                return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % 16 == 0;
            } else {
                // Mamba
                // (kernel only supports d_state == 16, d_head == 1, n_head % 128 == 0, n_group == 1)
                return op->src[0]->ne[0] == 16 && op->src[0]->ne[1] == 1 && op->src[0]->ne[2] % 128 == 0 && op->src[4]->ne[1] == 1;
            }
        }
        case GGML_OP_SSM_CONV: {
            // assumes d_inner % threads == 0
            return op->src[0]->ne[1] % 128 == 0;
        }
        case GGML_OP_CONT:
            return true;
        case GGML_OP_DIAG_MASK_INF:
            return true;
        case GGML_OP_SOFT_MAX:
            return true;
        case GGML_OP_SOFT_MAX_BACK: {
            // op_params[1] holds max_bias (read as a float); only max_bias == 0 is supported
            float max_bias = 0.0f;
            memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float));
            return max_bias == 0.0f;
        }
        case GGML_OP_ROLL:
            if(op->src[0]->type == GGML_TYPE_F32) {
                return true;
            }
            return false;
        case GGML_OP_ROPE:
        case GGML_OP_ROPE_BACK: {
            // src0 elements must be tightly packed along dim 0 and ggml_is_contiguous_2 must hold
            return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
        }
        case GGML_OP_IM2COL:
        case GGML_OP_IM2COL_3D:
        case GGML_OP_CONV_2D:
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
        case GGML_OP_ACC:
            return true;
        case GGML_OP_SUM:
            return ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_TOP_K:
        case GGML_OP_ARGSORT:
            // without CUB only rows of at most 1024 elements are supported
#ifndef GGML_CUDA_USE_CUB
            return op->src[0]->ne[0] <= 1024;
#else
            return true;
#endif
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_GROUP_NORM:
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_PAD:
            return true;
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_RWKV_WKV6:
        case GGML_OP_GATED_LINEAR_ATTN:
        case GGML_OP_RWKV_WKV7:
            return true;
        case GGML_OP_FLASH_ATTN_EXT:
            // delegated to the device-specific flash-attention support check
            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
        case GGML_OP_OPT_STEP_ADAMW:
        case GGML_OP_OPT_STEP_SGD:
        case GGML_OP_FILL:
        case GGML_OP_CUMSUM:
        case GGML_OP_TRI:
        case GGML_OP_DIAG:
        case GGML_OP_SOLVE_TRI:
            return true;

        default:
            return false;
    }
}
4866
// A CUDA device accepts its own device/split buffer types; integrated devices additionally
// accept the CUDA host buffer type.
static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    const auto * dev_ctx    = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    const bool   integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;

    const bool is_device_buft = ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft);
    if (is_device_buft && buft->device == dev) {
        return true;
    }

    return integrated && ggml_backend_buft_is_cuda_host(buft);
}
4872
// Size compared against op_offload_min_batch_size when deciding whether to offload `op`.
// GET_ROWS reports 0; ops not listed fall back to the total number of rows.
static int64_t get_op_batch_size(const ggml_tensor * op) {
    const ggml_op kind = op->op;

    if (kind == GGML_OP_GET_ROWS) {
        return 0;
    }
    if (kind == GGML_OP_MUL_MAT) {
        return op->ne[1];
    }
    if (kind == GGML_OP_MUL_MAT_ID || kind == GGML_OP_ROPE || kind == GGML_OP_ROPE_BACK) {
        return op->ne[2];
    }
    return ggml_nrows(op);
}
4887
// Offload `op` only when its batch size reaches the device's configured minimum.
static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    const auto *  dev_ctx    = static_cast<const ggml_backend_cuda_device_context *>(dev->context);
    const int64_t batch_size = get_op_batch_size(op);
    return batch_size >= dev_ctx->op_offload_min_batch_size;
}
4893
// Creates a backend event wrapping a CUDA event (timing disabled) on this device.
// Returns nullptr when the build has GGML_CUDA_NO_PEER_COPY, matching caps.events == false
// reported by ggml_backend_cuda_device_get_props in that configuration.
static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) {
#ifdef GGML_CUDA_NO_PEER_COPY
    // `dev` is not needed in this configuration; mark it used to avoid -Wunused-parameter
    GGML_UNUSED(dev);
    return nullptr;
#else
    ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *)dev->context;

    // the event is created on the current device, so select this one first
    ggml_cuda_set_device(dev_ctx->device);

    cudaEvent_t event;
    CUDA_CHECK(cudaEventCreateWithFlags(&event, cudaEventDisableTiming));

    return new ggml_backend_event {
        /* .device  = */ dev,
        /* .context = */ event,
    };
#endif
}
4911
// Destroys the CUDA event, then deletes the wrapper allocated by event_new.
static void ggml_backend_cuda_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    (void) dev;
    cudaEvent_t cuda_event = (cudaEvent_t) event->context;
    CUDA_CHECK(cudaEventDestroy(cuda_event));
    delete event;
}
4918
// Blocks the host until the wrapped CUDA event has completed.
static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    (void) dev;
    cudaEvent_t cuda_event = (cudaEvent_t) event->context;
    CUDA_CHECK(cudaEventSynchronize(cuda_event));
}
4923
// Device vtable for CUDA devices. buffer_from_host_ptr is not provided (NULL), which agrees with
// caps.buffer_from_host_ptr == false as reported by ggml_backend_cuda_device_get_props.
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
    /* .get_name                = */ ggml_backend_cuda_device_get_name,
    /* .get_description         = */ ggml_backend_cuda_device_get_description,
    /* .get_memory              = */ ggml_backend_cuda_device_get_memory,
    /* .get_type                = */ ggml_backend_cuda_device_get_type,
    /* .get_props               = */ ggml_backend_cuda_device_get_props,
    /* .init_backend            = */ ggml_backend_cuda_device_init_backend,
    /* .get_buffer_type         = */ ggml_backend_cuda_device_get_buffer_type,
    /* .get_host_buffer_type    = */ ggml_backend_cuda_device_get_host_buffer_type,
    /* .buffer_from_host_ptr    = */ NULL,
    /* .supports_op             = */ ggml_backend_cuda_device_supports_op,
    /* .supports_buft           = */ ggml_backend_cuda_device_supports_buft,
    /* .offload_op              = */ ggml_backend_cuda_device_offload_op,
    /* .event_new               = */ ggml_backend_cuda_device_event_new,
    /* .event_free              = */ ggml_backend_cuda_device_event_free,
    /* .event_synchronize       = */ ggml_backend_cuda_device_event_synchronize,
};
4941
4942// backend reg
4943
// Registry state of the CUDA backend, reached through `reg->context`.
struct ggml_backend_cuda_reg_context {
    // devices exposed by the registry; indexed by ggml_backend_cuda_reg_get_device and
    // counted by ggml_backend_cuda_reg_get_device_count
    std::vector<ggml_backend_dev_t> devices;
};
4947
// Registry name (GGML_CUDA_NAME); independent of `reg`.
static const char * ggml_backend_cuda_reg_get_name(ggml_backend_reg_t reg) {
    (void) reg;
    return GGML_CUDA_NAME;
}
4952
// Number of devices held by the registry context.
static size_t ggml_backend_cuda_reg_get_device_count(ggml_backend_reg_t reg) {
    const auto * reg_ctx = static_cast<const ggml_backend_cuda_reg_context *>(reg->context);
    return reg_ctx->devices.size();
}
4957
4958static ggml_backend_dev_t ggml_backend_cuda_reg_get_device(ggml_backend_reg_t reg, size_t index) {
4959    ggml_backend_cuda_reg_context * ctx = (ggml_backend_cuda_reg_context *)reg->context;
4960    GGML_ASSERT(index < ctx->devices.size());
4961    return ctx->devices[index];
4962}
4963
4964static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t reg) {
4965    static std::vector<ggml_backend_feature> features = []() {
4966        std::vector<ggml_backend_feature> features;
4967    #define _STRINGIFY(...) #__VA_ARGS__
4968    #define STRINGIFY(...) _STRINGIFY(__VA_ARGS__)
4969
4970    #ifdef __CUDA_ARCH_LIST__
4971        features.push_back({ "ARCHS", STRINGIFY(__CUDA_ARCH_LIST__) });
4972    #endif
4973
4974    #ifdef GGML_CUDA_FORCE_MMQ
4975        features.push_back({ "FORCE_MMQ", "1" });
4976    #endif
4977
4978    #ifdef GGML_CUDA_FORCE_CUBLAS
4979        features.push_back({ "FORCE_CUBLAS", "1" });
4980    #endif
4981
4982    #ifndef GGML_USE_VMM
4983        features.push_back({ "NO_VMM", "1" });
4984    #endif
4985
4986    #ifdef GGML_CUDA_NO_PEER_COPY
4987        features.push_back({ "NO_PEER_COPY", "1" });
4988    #endif
4989
4990    #ifdef GGML_CUDA_USE_GRAPHS
4991        features.push_back({ "USE_GRAPHS", "1" });
4992    #endif
4993
4994    #ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
4995        features.push_back({ "PEER_MAX_BATCH_SIZE", STRINGIFY(GGML_CUDA_PEER_MAX_BATCH_SIZE) });
4996    #endif
4997
4998    #ifdef GGML_CUDA_FA_ALL_QUANTS
4999        features.push_back({ "FA_ALL_QUANTS", "1" });
5000    #endif
5001
5002    {
5003        const auto & info = ggml_cuda_info();
5004        for (int id = 0; id < info.device_count; ++id) {
5005            if (blackwell_mma_available(info.devices[id].cc)) {
5006                features.push_back({ "BLACKWELL_NATIVE_FP4", "1"});
5007                break;
5008            }
5009        }
5010    }
5011
5012    #undef _STRINGIFY
5013    #undef STRINGIFY
5014
5015        features.push_back({ nullptr, nullptr });
5016
5017        return features;
5018    }();
5019
5020    return features.data();
5021
5022    GGML_UNUSED(reg);
5023}
5024
5025static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
5026    GGML_UNUSED(reg);
5027    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
5028        return (void *)ggml_backend_cuda_split_buffer_type;
5029    }
5030    if (strcmp(name, "ggml_backend_register_host_buffer") == 0) {
5031        return (void *)ggml_backend_cuda_register_host_buffer;
5032    }
5033    if (strcmp(name, "ggml_backend_unregister_host_buffer") == 0) {
5034        return (void *)ggml_backend_cuda_unregister_host_buffer;
5035    }
5036    if (strcmp(name, "ggml_backend_get_features") == 0) {
5037        return (void *)ggml_backend_cuda_get_features;
5038    }
5039    return nullptr;
5040}
5041
// Registry vtable for the CUDA backend; used as the .iface of the registry object
// built in ggml_backend_cuda_reg(). Entries are positional and must follow the
// field order of ggml_backend_reg_i.
static const ggml_backend_reg_i ggml_backend_cuda_reg_interface = {
    /* .get_name          = */ ggml_backend_cuda_reg_get_name,
    /* .get_device_count  = */ ggml_backend_cuda_reg_get_device_count,
    /* .get_device        = */ ggml_backend_cuda_reg_get_device,
    /* .get_proc_address  = */ ggml_backend_cuda_reg_get_proc_address,
};
5048
5049// backend registry
5050ggml_backend_reg_t ggml_backend_cuda_reg() {
5051    static ggml_backend_reg reg;
5052    static bool initialized = false;
5053
5054    {
5055        static std::mutex mutex;
5056        std::lock_guard<std::mutex> lock(mutex);
5057        if (!initialized) {
5058            ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
5059            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
5060
5061            for (int i = 0; i < ggml_cuda_info().device_count; i++) {
5062                ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
5063                dev_ctx->device = i;
5064                dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
5065
5066                cudaDeviceProp prop;
5067                CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
5068                dev_ctx->description = prop.name;
5069
5070                char pci_bus_id[16] = {};
5071                snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
5072                dev_ctx->pci_bus_id = pci_bus_id;
5073                dev_ctx->op_offload_min_batch_size = min_batch_size;
5074
5075                ggml_backend_dev_t dev = new ggml_backend_device {
5076                    /* .iface   = */ ggml_backend_cuda_device_interface,
5077                    /* .reg     = */ &reg,
5078                    /* .context = */ dev_ctx
5079                };
5080                ctx->devices.push_back(dev);
5081            }
5082
5083            reg = ggml_backend_reg {
5084                /* .api_version = */ GGML_BACKEND_API_VERSION,
5085                /* .iface       = */ ggml_backend_cuda_reg_interface,
5086                /* .context     = */ ctx
5087            };
5088        }
5089
5090        initialized = true;
5091    }
5092
5093    return &reg;
5094}
5095
5096ggml_backend_t ggml_backend_cuda_init(int device) {
5097    if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
5098        GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device);
5099        return nullptr;
5100    }
5101
5102    ggml_backend_cuda_context * ctx = new ggml_backend_cuda_context(device);
5103    if (ctx == nullptr) {
5104        GGML_LOG_ERROR("%s: failed to allocate context\n", __func__);
5105        return nullptr;
5106    }
5107
5108    ggml_backend_t cuda_backend = new ggml_backend {
5109        /* .guid    = */ ggml_backend_cuda_guid(),
5110        /* .iface   = */ ggml_backend_cuda_interface,
5111        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device),
5112        /* .context = */ ctx,
5113    };
5114
5115    return cuda_backend;
5116}
5117
// Hooks ggml_backend_cuda_reg into the backend dynamic-loading machinery.
// NOTE(review): presumably this expands to the exported entry point(s) used when the backend
// is built as a loadable module; confirm against the GGML_BACKEND_DL_IMPL definition.
GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)