1/*
   2 * Copyright (c) 2023-2026 The ggml authors
   3 *
   4 * Permission is hereby granted, free of charge, to any person obtaining a copy
   5 * of this software and associated documentation files (the "Software"), to
   6 * deal in the Software without restriction, including without limitation the
   7 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
   8 * sell copies of the Software, and to permit persons to whom the Software is
   9 * furnished to do so, subject to the following conditions:
  10 *
  11 * The above copyright notice and this permission notice shall be included in
  12 * all copies or substantial portions of the Software.
  13 *
  14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
  19 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  20 * IN THE SOFTWARE.
  21 */
  22
  23#include "ggml-cann.h"
  24
  25#include "ggml-backend-impl.h"
  26#include "ggml-cann/aclnn_ops.h"
  27#include "ggml-cann/common.h"
  28#include "ggml-impl.h"
  29#include "ggml.h"
  30
  31#include <acl/acl.h>
  32#include <aclnnop/aclnn_trans_matmul_weight.h>
  33#include <stdarg.h>
  34
  35#include <chrono>
  36#include <cmath>
  37#include <cstdio>
  38#include <cstring>
  39#include <mutex>
  40#include <optional>
  41#include <queue>
  42#include <unordered_set>
  43
  44#define GGML_COMMON_DECL_C
  45
  46#include "ggml-common.h"
  47
  48#define GGML_CANN_NAME "CANN"
  49
  50/**
  51 * @brief Handles CANN errors by printing an error message and aborting.
  52 *
  53 * @param stmt The statement that caused the error.
  54 * @param func The function in which the error occurred.
  55 * @param file The file in which the error occurred.
  56 * @param line The line number where the error occurred.
  57 * @param msg The error message.
  58 */
  59[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
  60    int32_t id = -1;
  61    aclrtGetDevice(&id);
  62
  63    GGML_LOG_ERROR("CANN error: %s\n", msg);
  64    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func, file, line);
  65    GGML_LOG_ERROR("  %s\n", stmt);
  66    // abort with GGML_ASSERT to get a stack trace
  67    GGML_ABORT("CANN error");
  68}
  69
  70// Thread-local variable to record the current device of this thread.
  71thread_local int g_current_cann_device = -1;
  72
  73/**
  74 * @brief Set the CANN device to be used.
  75 *
  76 * @param device The target device ID to set.
  77 */
  78void ggml_cann_set_device(const int32_t device) {
  79    // int current_device = -1;
  80    // Note: In some CANN versions, if no device has been set yet,
  81    //       aclrtGetDevice(&current_device) may return 0 by default.
  82    // aclrtGetDevice(&current_device);
  83
  84    // If the current device is already the target one, no need to switch.
  85    if (device == g_current_cann_device) {
  86        return;
  87    }
  88
  89    // Switch to the new device.
  90    ACL_CHECK(aclrtSetDevice(device));
  91
  92    // Update the global device record.
  93    g_current_cann_device = device;
  94}
  95
  96/**
  97 * @brief Get the value of the specified environment variable (name) as lowercase.
  98 *        if not empty, return a std::string object
  99 */
 100std::optional<std::string> get_env_as_lowercase(const std::string & name) {
 101    const char * val = std::getenv(name.c_str());
 102    if (!val) {
 103        return std::nullopt;
 104    }
 105    std::string res = std::string(val);
 106    std::transform(res.begin(), res.end(), res.begin(), ::tolower);
 107    return res;
 108}
 109
 110/**
 111 * @brief Verify whether the environment variable is a valid value.
 112 */
 113bool parse_bool(const std::string & value) {
 114    static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
 115    return valid_values.find(value) != valid_values.end();
 116}
 117
 118/**
 119 * @brief Parse a string as an integer, returning 0 if invalid.
 120 *
 121 * This function attempts to convert the input string `value` to an `int`.
 122 * If the string is not a valid integer or is out of the `int` range,
 123 * it returns 0.
 124 *
 125 * @param value The string to parse.
 126 * @return The parsed integer, or 0 if conversion fails.
 127 */
 128int parse_integer(const std::string & value) {
 129    try {
 130        return std::stoi(value);
 131    } catch (...) {
 132        return 0;
 133    }
 134}
 135
 136/**
 137 * @brief Initialize the CANN device information.
 138 *
 139 * This function initializes the CANN device information by obtaining the
 140 * device count and setting the memory allocation granularity for each device.
 141 *
 142 * @return A structure containing the device information.
 143 */
 144static ggml_cann_device_info ggml_cann_init() {
 145    ggml_cann_device_info info = {};
 146
 147    aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count);
 148
 149    if (err != ACL_SUCCESS) {
 150        GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg());
 151        return info;
 152    }
 153
 154    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
 155
 156    for (int id = 0; id < info.device_count; ++id) {
 157        aclrtPhysicalMemProp prop = {};
 158        prop.handleType           = ACL_MEM_HANDLE_TYPE_NONE;
 159        prop.allocationType       = ACL_MEM_ALLOCATION_TYPE_PINNED;
 160        prop.memAttr              = ACL_HBM_MEM_HUGE;
 161        prop.location.type        = ACL_MEM_LOCATION_TYPE_DEVICE;
 162        prop.location.id          = id;
 163        prop.reserve              = 0;
 164        err                       = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
 165                                                                     &info.devices[id].vmm_granularity);
 166        info.devices[id].vmm      = err == ACL_SUCCESS;
 167
 168        size_t free, total;
 169        ggml_backend_cann_get_device_memory(id, &free, &total);
 170        info.devices[id].total_vram = free;
 171    }
 172
 173    // TODO: add more device info later.
 174    return info;
 175}
 176
 177/**
 178 * @brief Retrieve the CANN device information.
 179 *
 180 * This function returns a reference to a structure containing the CANN device
 181 * information. The device information is initialized once and reused on
 182 * subsequent calls.
 183 *
 184 * @return A reference to the structure containing the device information.
 185 */
 186const ggml_cann_device_info & ggml_cann_info() {
 187    static ggml_cann_device_info info = ggml_cann_init();
 188    return info;
 189}
 190
 191//#define DEBUG_CANN_MALLOC
 192/**
 193 * @brief A pool of CANN buffers(priority segment buffer).
 194 *
 195 * This class manages a pool of CANN buffers for a specific device.
 196 */
 197struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
 198    /**
 199     * @brief The maximum reuse margin for a buffer.
 200     */
 201    static const size_t max_reuse_margin = 1ull << 22;  // 4MB
 202
 203    /**
 204     * @brief The minimum free margin for a buffer.
 205     */
 206    static const size_t min_free_margin = 1ull << 20;  // 1MB
 207
 208    /**
 209     * @brief The alignment for buffer allocation.
 210     */
 211    static const size_t alignment = 128;
 212
 213    /**
 214     * @brief The device ID associated with this buffer pool.
 215     */
 216    int device;
 217
 218    /**
 219     * @brief Whether to disable clean during buffer allocation.
 220     */
 221    bool disable_clean = false;
 222
 223    /**
 224     * @brief Structure representing a CANN buffer.
 225     */
 226    struct ggml_cann_buffer {
 227        void *                                ptr  = nullptr;  ///< Pointer to the buffer.
 228        size_t                                size = 0;        ///< Size of the buffer.
 229        std::chrono::steady_clock::time_point last_used;       ///< Last used time.
 230
 231        bool operator>(const ggml_cann_buffer & other) const { return size > other.size; }
 232    };
 233
 234    /**
 235     * @brief Array of CANN buffers in the pool.
 236     */
 237    std::unordered_map<void *, size_t>                                                   buffer_pool;
 238    std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>, std::greater<>> free_buffers;
 239
 240    /**
 241     * @brief Total size of all buffers in the pool.
 242     */
 243    size_t pool_size = 0;
 244
 245    /**
 246     * @brief Constructor to initialize the buffer pool for a specific device.
 247     *
 248     * @param device The device ID to associate with this buffer pool.
 249     */
 250    explicit ggml_cann_pool_buf_prio(int device) : device(device) {
 251        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
 252    }
 253
 254    /**
 255     * @brief Destructor to free all buffers in the pool.
 256     */
 257    ~ggml_cann_pool_buf_prio() {
 258        ggml_cann_set_device(device);
 259        for (auto & [b_ptr, b_size] : buffer_pool) {
 260            aclrtFree(b_ptr);
 261            pool_size -= b_size;
 262        }
 263        buffer_pool.clear();
 264        GGML_ASSERT(pool_size == 0);
 265    }
 266
 267    /**
 268     * @brief Allocate a buffer of the given size.
 269     *
 270     * @param size The size of the buffer to allocate.
 271     * @param actual_size A pointer to a variable to receive the actual size of
 272     * the allocated buffer.
 273     * @return A pointer to the allocated buffer.
 274     */
 275    void * alloc(size_t size, size_t * actual_size) override {
 276        size = GGML_PAD(size, alignment);
 277        if (size == 0) {
 278            size = alignment;
 279        }
 280
 281        void * ptr = nullptr;
 282        auto   now = std::chrono::steady_clock::now();
 283
 284        std::vector<ggml_cann_buffer> free_buffers_rest;
 285        free_buffers_rest.reserve(free_buffers.size());
 286        while (!free_buffers.empty()) {
 287            auto b = free_buffers.top();
 288            free_buffers.pop();
 289
 290            if (b.size >= size) {
 291                // reuse the buffer if the size is enough
 292                const size_t margin = b.size - size;
 293                if (margin <= max_reuse_margin) {
 294                    *actual_size = b.size;
 295                    ptr          = b.ptr;
 296#ifdef DEBUG_CANN_MALLOC
 297                    GGML_LOG_INFO(
 298                        "cann pool[%d]: reused   %p, "
 299                        "pool_size = %5u MB, "
 300                        "size = %5u MB, "
 301                        "margin = %5u MB\n",
 302                        device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 303                        (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
 304                        (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
 305#endif
 306                    break;
 307                }
 308            }
 309
 310            bool should_clean = !disable_clean && b.size > min_free_margin &&
 311                                std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
 312            if (should_clean) {
 313                // free the buffer if the size is needed to be freed
 314                ACL_CHECK(aclrtFree(b.ptr));
 315                pool_size -= b.size;
 316                buffer_pool.erase(b.ptr);
 317#ifdef DEBUG_CANN_MALLOC
 318                GGML_LOG_INFO(
 319                    "cann pool[%d]: clean    %p, "
 320                    "pool_size = %5u MB, "
 321                    "size = %5u MB\n",
 322                    device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 323                    (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
 324#endif
 325                continue;
 326            }
 327            free_buffers_rest.push_back(b);
 328        }
 329        for (ggml_cann_buffer & b : free_buffers_rest) {
 330            free_buffers.push(std::move(b));
 331        }
 332
 333#ifdef DEBUG_CANN_MALLOC
 334        GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device,
 335                      (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
 336#endif
 337        if (ptr != nullptr) {
 338            return ptr;
 339        }
 340
 341        // allocate a new buffer if no buffer can be reused
 342        ggml_cann_set_device(device);
 343        ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
 344        *actual_size = size;
 345        pool_size += size;
 346#ifdef DEBUG_CANN_MALLOC
 347        GGML_LOG_INFO(
 348            "cann pool[%d]: allocate %p, "
 349            "pool_size = %5u MB, "
 350            "size = %5u MB\n",
 351            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 352            (uint32_t) (GGML_PAD(size, 1048576) / 1048576));
 353#endif
 354        buffer_pool.emplace(ptr, size);
 355        return ptr;
 356    }
 357
 358    /**
 359     * @brief Free a buffer and return it to the pool.
 360     *
 361     * @param ptr Pointer to the buffer to free.
 362     * @param size Size of the buffer to free.
 363     */
 364    void free(void * ptr, size_t size) override {
 365        GGML_UNUSED(size);
 366        auto it = buffer_pool.find(ptr);
 367        if (it == buffer_pool.end()) {
 368            GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr);
 369        }
 370
 371        auto now = std::chrono::steady_clock::now();
 372        free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now });
 373#ifdef DEBUG_CANN_MALLOC
 374        GGML_LOG_INFO(
 375            "cann pool[%d]: return   %p, "
 376            "pool_size = %5u MB\n",
 377            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
 378#endif
 379    }
 380};
 381
 382/**
 383 * @brief A pool of CANN buffers(segment buffer).
 384 *
 385 * This class manages a pool of CANN buffers for a specific device.
 386 */
 387struct ggml_cann_pool_buf : public ggml_cann_pool {
 388    /**
 389     * @brief The maximum reuse margin for a buffer.
 390     */
 391    static const size_t max_reuse_margin = 1ull << 22;  // 4MB
 392
 393    /**
 394     * @brief The minimum free margin for a buffer.
 395     */
 396    static const size_t min_free_margin = 1ull << 20;  // 1MB
 397
 398    /**
 399     * @brief The alignment for buffer allocation.
 400     */
 401    static const size_t alignment = 128;
 402
 403    /**
 404     * @brief The maximum number of buffers in the pool.
 405     */
 406    static const int MAX_BUFFERS = 256;
 407
 408    /**
 409     * @brief The device ID associated with this buffer pool.
 410     */
 411    int device;
 412
 413    /**
 414     * @brief Whether to disable clean during buffer allocation.
 415     */
 416    bool disable_clean = false;
 417
 418    /**
 419     * @brief Structure representing a CANN buffer.
 420     */
 421    struct ggml_cann_buffer {
 422        void *                                ptr  = nullptr;  ///< Pointer to the buffer memory.
 423        size_t                                size = 0;        ///< Size of the buffer.
 424        bool                                  used = false;    ///< Whether the buffer is currently in use.
 425        std::chrono::steady_clock::time_point last_used;       ///< Last used time.
 426    };
 427
 428    /**
 429     * @brief Array of CANN buffers in the pool.
 430     */
 431    ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
 432
 433    /**
 434     * @brief Total size of all buffers in the pool.
 435     */
 436    size_t pool_size = 0;
 437
 438    /**
 439     * @brief Constructor to initialize the buffer pool for a specific device.
 440     *
 441     * @param device The device ID to associate with this buffer pool.
 442     */
 443    explicit ggml_cann_pool_buf(int device) : device(device) {
 444        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
 445    }
 446
 447    /**
 448     * @brief Destructor to free all buffers in the pool.
 449     */
 450    ~ggml_cann_pool_buf() {
 451        ggml_cann_set_device(device);
 452        for (int i = 0; i < MAX_BUFFERS; ++i) {
 453            ggml_cann_buffer & b = buffer_pool[i];
 454            if (b.ptr != nullptr) {
 455                aclrtFree(b.ptr);
 456                pool_size -= b.size;
 457            }
 458        }
 459        GGML_ASSERT(pool_size == 0);
 460    }
 461
 462    /**
 463     * @brief Allocate a buffer of the given size.
 464     *
 465     * @param size The size of the buffer to allocate.
 466     * @param actual_size A pointer to a variable to receive the actual size of
 467     * the allocated buffer.
 468     * @return A pointer to the allocated buffer.
 469     */
 470    void * alloc(size_t size, size_t * actual_size) override {
 471        size = GGML_PAD(size, alignment);
 472        if (size == 0) {
 473            size = alignment;
 474        }
 475
 476        void * ptr = nullptr;
 477        auto   now = std::chrono::steady_clock::now();
 478
 479        int i = 0;
 480        for (; i < MAX_BUFFERS; ++i) {
 481            ggml_cann_buffer & b = buffer_pool[i];
 482            if (b.ptr == nullptr) {
 483                break;
 484            }
 485            if (b.used) {
 486                continue;
 487            }
 488            if (b.size >= size) {
 489                // reuse the buffer if the size is enough
 490                const size_t margin = b.size - size;
 491                if (margin <= max_reuse_margin) {
 492                    *actual_size = b.size;
 493                    b.used       = true;
 494                    ptr          = b.ptr;
 495#ifdef DEBUG_CANN_MALLOC
 496                    GGML_LOG_INFO(
 497                        "cann pool[%d]: reused   %p, "
 498                        "pool_size = %5u MB, "
 499                        "size = %5u MB, "
 500                        "margin = %5u MB\n",
 501                        device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 502                        (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
 503                        (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
 504#endif
 505                    break;
 506                }
 507            }
 508
 509            bool should_clean = !disable_clean && b.size > min_free_margin &&
 510                                std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
 511            if (should_clean) {
 512                // free the buffer if the size is needed to be freed
 513                ACL_CHECK(aclrtFree(b.ptr));
 514                pool_size -= b.size;
 515#ifdef DEBUG_CANN_MALLOC
 516                GGML_LOG_INFO(
 517                    "cann pool[%d]: clean    %p, "
 518                    "pool_size = %5u MB, "
 519                    "size = %5u MB\n",
 520                    device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 521                    (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
 522#endif
 523                b.ptr = nullptr;
 524            }
 525        }
 526        if (ptr != nullptr) {
 527            return ptr;
 528        }
 529
 530        if (i < MAX_BUFFERS) {
 531            // allocate a new buffer if no buffer can be reused
 532            ggml_cann_buffer & b = buffer_pool[i];
 533            ggml_cann_set_device(device);
 534            ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
 535            pool_size += size;
 536            *actual_size = size;
 537            b.size       = size;
 538            b.used       = true;
 539            if (i >= MAX_BUFFERS - 8) {
 540                GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
 541            }
 542#ifdef DEBUG_CANN_MALLOC
 543            GGML_LOG_INFO(
 544                "cann pool[%d]: allocate %p, "
 545                "pool_size = %5u MB, "
 546                "size = %5u MB\n",
 547                device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
 548                (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
 549#endif
 550            return b.ptr;
 551        }
 552
 553        GGML_ABORT("cann pool[%d]: slots full\n", device);
 554    }
 555
 556    /**
 557     * @brief Free a buffer and return it to the pool.
 558     *
 559     * @param ptr Pointer to the buffer to free.
 560     * @param size Size of the buffer to free.
 561     */
 562    void free(void * ptr, size_t size) override {
 563        GGML_UNUSED(size);
 564        for (int i = 0; i < MAX_BUFFERS; ++i) {
 565            ggml_cann_buffer & b = buffer_pool[i];
 566            if (b.ptr != ptr) {
 567                continue;
 568            }
 569            b.used      = false;
 570            b.last_used = std::chrono::steady_clock::now();
 571#ifdef DEBUG_CANN_MALLOC
 572            GGML_LOG_INFO(
 573                "cann pool[%d]: return   %p, "
 574                "pool_size = %5u MB\n",
 575                device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
 576#endif
 577            return;
 578        }
 579        GGML_ABORT("cann pool[%d]: slots full\n", device);
 580    }
 581};
 582
 583/**
 584 * @brief A pool of CANN buffers with virtual memory.
 585 *
 586 * This class manages a pool of CANN buffers with virtual memory for a specific
 587 * device.
 588 */
 589struct ggml_cann_pool_vmm : public ggml_cann_pool {
 590    /**
 591     * @brief The maximum size of the virtual memory pool (32 GB).
 592     */
 593    size_t max_size;
 594
 595    /**
 596     * @brief The device ID associated with this buffer pool.
 597     */
 598    int device;
 599
 600    /**
 601     * @brief Pointer to the start of the virtual memory pool.
 602     */
 603    void * pool_addr = 0;
 604
 605    /**
 606     * @brief Amount of virtual memory used in the pool.
 607     */
 608    size_t pool_used = 0;
 609
 610    /**
 611     * @brief Total size of the virtual memory pool.
 612     */
 613    size_t pool_size = 0;
 614
 615    /**
 616     * @brief Allocation granularity for the virtual memory pool.
 617     */
 618    size_t granularity;
 619
 620    /**
 621     * @brief Handles for the physical memory allocated.
 622     */
 623    std::vector<aclrtDrvMemHandle> handles;
 624
 625    /**
 626     * @brief Offsets for the mapped memory regions.
 627     */
 628    std::vector<void *> map_offsets;
 629
 630    /**
 631     * @brief Constructor to initialize the buffer pool with virtual memory for
 632     * a specific device.
 633     *
 634     * @param device The device ID to associate with this buffer pool.
 635     */
 636    explicit ggml_cann_pool_vmm(int device) : device(device) {
 637        auto dev    = ggml_cann_info().devices[device];
 638        granularity = dev.vmm_granularity;
 639        max_size    = dev.total_vram;
 640    }
 641
 642    /**
 643     * @brief Destructor to free all buffers in the virtual memory pool.
 644     */
 645    ~ggml_cann_pool_vmm() {
 646        if (pool_addr != 0) {
 647            for (auto & offset : map_offsets) {
 648                ACL_CHECK(aclrtUnmapMem(offset));
 649            }
 650            for (auto & handle : handles) {
 651                ACL_CHECK(aclrtFreePhysical(handle));
 652            }
 653            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
 654        }
 655    }
 656
 657    /**
 658     * @brief Allocate a buffer of the given size in the virtual memory pool.
 659     *
 660     * @param size The size of the buffer to allocate.
 661     * @param actual_size A pointer to a variable to receive the actual size of
 662     * the allocated buffer.
 663     * @return A pointer to the allocated buffer.
 664     */
 665    void * alloc(size_t size, size_t * actual_size) override {
 666        // round up the allocation size to the alignment to ensure that all
 667        // allocations are aligned for all data types
 668        const size_t alignment = 128;
 669        size                   = GGML_PAD(size, alignment);
 670        if (size == 0) {
 671            size = alignment;
 672        }
 673
 674        size_t avail = pool_size - pool_used;
 675
 676        if (size > avail) {
 677            // round up to the next multiple of the granularity
 678            size_t reserve_size = size - avail;
 679            reserve_size        = GGML_PAD(reserve_size, granularity);
 680
 681            GGML_ASSERT(pool_size + reserve_size <= max_size);
 682
 683            // allocate more physical memory
 684            aclrtPhysicalMemProp prop = {};
 685            prop.handleType           = ACL_MEM_HANDLE_TYPE_NONE;
 686            prop.allocationType       = ACL_MEM_ALLOCATION_TYPE_PINNED;
 687            prop.memAttr              = ACL_HBM_MEM_HUGE;
 688            prop.location.type        = ACL_MEM_LOCATION_TYPE_DEVICE;
 689            prop.location.id          = device;
 690            prop.reserve              = 0;
 691            aclrtDrvMemHandle handle;
 692            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
 693
 694            // reserve virtual address space (if not already reserved)
 695            if (pool_addr == 0) {
 696                ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1));
 697            }
 698
 699            // map at the end of the pool
 700            ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0));
 701
 702            handles.push_back(handle);
 703            map_offsets.push_back((char *) pool_addr + pool_size);
 704
 705            // add to the pool
 706            pool_size += reserve_size;
 707
 708#ifdef DEBUG_CANN_MALLOC
 709            GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", device,
 710                          (unsigned long long) (pool_size / 1024 / 1024),
 711                          (unsigned long long) (reserve_size / 1024 / 1024));
 712#endif
 713        }
 714
 715        GGML_ASSERT(pool_addr != 0);
 716
 717        void * ptr   = (void *) ((char *) pool_addr + pool_used);
 718        *actual_size = size;
 719        pool_used += size;
 720
 721#ifdef DEBUG_CANN_MALLOC
 722        GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size,
 723                      (unsigned long long) ptr);
 724#endif
 725        return ptr;
 726    }
 727
 728    /**
 729     * @brief Free a buffer and return it to the virtual memory pool.
 730     *
 731     * @param ptr Pointer to the buffer to free.
 732     * @param size Size of the buffer to free.
 733     */
 734    void free(void * ptr, size_t size) override {
 735#ifdef DEBUG_CANN_MALLOC
 736        GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size,
 737                      (unsigned long long) ptr);
 738#endif
 739
 740        pool_used -= size;
 741
 742        // all deallocations must be in reverse order of the allocations
 743        GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used));
 744    }
 745};
 746
 747/**
 748 * @brief Create a new CANN pool for a specific device.
 749 *
 750 * Factory method to create a new CANN pool object based on the device type.
 751 *
 752 * @param device The device ID for which to create the pool.
 753 * @return A unique pointer to the created CANN pool.
 754 */
 755std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
 756    std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
 757
 758    if (mem_pool_type == "prio") {
 759        GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
 760        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
 761    }
 762
 763    if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
 764        GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
 765        return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
 766    }
 767
 768    GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
 769    return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
 770}
 771
 772// cann buffer
 773/**
 774 * @brief Context for managing a CANN buffer associated with a specific device.
 775 *
 776 * This structure holds information about a CANN buffer, including the device
 777 * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
 778 */
 779struct ggml_backend_cann_buffer_context {
 780    int32_t device;             ///< The device ID associated with this buffer context.
 781    void *  dev_ptr = nullptr;  ///< Pointer to the device memory allocated for the buffer.
 782
 783    /**
 784     * @brief Constructor to initialize the CANN buffer context.
 785     *
 786     * @param device The device ID associated with this buffer context.
 787     * @param dev_ptr Pointer to the device memory allocated for the buffer.
 788     */
 789    ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
 790
 791    /**
 792     * @brief Destructor to free the device memory allocated for the buffer.
 793     */
 794    ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
 795};
 796
 797// cann buffer type
 798/**
 799 * @brief Structure representing context information for a specific backend
 800 * buffer type.
 801 */
 802struct ggml_backend_cann_buffer_type_context {
 803    int32_t     device; /**< Device identifier associated with the buffer context. */
 804    std::string name;   /**< Name associated with the buffer context. */
 805};
 806
 807/**
 808 * @brief Retrieves the name associated with a CANN buffer type.
 809 *
 810 * This function returns the descriptive name associated with the specified
 811 * CANN buffer type context.
 812 *
 813 * @param buft Pointer to the buffer type context.
 814 * @return Const pointer to the C-style string containing the name.
 815 */
 816static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
 817    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
 818
 819    return buft_ctx->name.c_str();
 820}
 821
 822/**
 823 * @brief Checks if the backend buffer type is associated with the CANN backend.
 824 *
 825 * This function checks whether the provided backend buffer type is associated
 826 * with the CANN backend based on the comparison of its name retrieval function
 827 * pointer.
 828 *
 829 * @param buft Pointer to the backend buffer type to check.
 830 * @return bool Returns true if the buffer type is associated with the CANN
 831 * backend, otherwise false.
 832 */
 833static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
 834    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
 835}
 836
 837/**
 838 * @brief Free resources associated with a CANN buffer.
 839 *
 840 * This function frees the resources associated with a CANN buffer, including
 841 * its context.
 842 *
 843 * @param buffer The CANN buffer to free.
 844 */
 845static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 846    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
 847    delete ctx;
 848}
 849
 850/**
 851 * @brief Retrieve the base pointer of a CANN buffer.
 852 *
 853 * This function returns the base pointer of a CANN buffer, which points to the
 854 * device memory allocated for the buffer.
 855 *
 856 * @param buffer The CANN buffer whose base pointer is to be retrieved.
 857 * @return A pointer to the base of the device memory allocated for the buffer.
 858 */
 859static void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) {
 860    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
 861    return ctx->dev_ptr;
 862}
 863
 864/**
 865 * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
 866 * processing.
 867 *
 868 * This function transforms quantized Q4.0 tensor data into a format suitable
 869 * for CANN processing. It extracts quantization values and scales from the
 870 * source data and prepares them in a format expected by CANN operations.
 871 *
 872 * @param tensor Pointer to the tensor information.
 873 * @param src Pointer to the source data in Q4.0 format.
 874 * @param dst Pointer to the destination buffer where transformed data will be
 875 * stored.
 876 */
 877static void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) {
 878    int64_t n_elems     = ggml_nelements(tensor);
 879    int64_t groups      = n_elems / QK4_0;
 880    size_t  quant_bytes = n_elems * sizeof(uint8_t) / 2;
 881
 882    uint8_t *  quant_offset = (uint8_t *) dst;
 883    uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
 884
 885    for (int i = 0; i < groups; i++) {
 886        const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0));
 887        *scale_offset            = group->d;
 888        scale_offset++;
 889
 890        // 0-15
 891        for (int j = 0; j < QK4_0 / 2; j += 2) {
 892            (*quant_offset) = (group->qs[j] & 0x0F);
 893            (*quant_offset) |= ((group->qs[j + 1] << 4));
 894            quant_offset++;
 895        }
 896
 897        // 16-31
 898        for (int j = 0; j < QK4_0 / 2; j += 2) {
 899            (*quant_offset) = (group->qs[j] >> 4);
 900            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
 901            quant_offset++;
 902        }
 903    }
 904
 905    // put (uint4b_t -8) into int4b_t
 906    for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) {
 907        (*quant_offset) ^= 0x88;
 908    }
 909}
 910
 911/**
 912 * @brief Transform CANN processed data back into quantized Q4.0 format.
 913 *
 914 * This function transforms CANN processed data back into quantized Q4.0 format.
 915 * It reverses the transformation performed by
 916 * ggml_backend_cann_transform_q4_0(), converting the data back into its
 917 * original quantized form.
 918 *
 919 * @param tensor Pointer to the tensor information.
 920 * @param src Pointer to the source buffer containing transformed data.
 921 * @param dst Pointer to the destination buffer where the Q4.0 formatted data
 922 * will be stored.
 923 */
 924static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) {
 925    int64_t n_elems     = ggml_nelements(tensor);
 926    int64_t groups      = n_elems / QK4_0;
 927    size_t  quant_bytes = n_elems * sizeof(uint8_t) / 2;
 928
 929    uint8_t *  quant_offset = (uint8_t *) src;
 930    uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes);
 931
 932    for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) {
 933        (*quant_offset) ^= 0x88;
 934    }
 935    quant_offset = (uint8_t *) src;
 936
 937    for (int i = 0; i < groups; i++) {
 938        block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0));
 939        group->d           = *scale_offset;
 940        scale_offset++;
 941
 942        // 0-15
 943        for (int j = 0; j < QK4_0 / 2; j += 2) {
 944            group->qs[j]     = ((*quant_offset) & 0x0F);
 945            group->qs[j + 1] = ((*quant_offset) >> 4);
 946            quant_offset++;
 947        }
 948
 949        // 16-31
 950        for (int j = 0; j < QK4_0 / 2; j += 2) {
 951            group->qs[j] |= ((*quant_offset) << 4);
 952            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
 953            quant_offset++;
 954        }
 955    }
 956}
 957
 958/**
 959 * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
 960 * processing.
 961 *
 962 * This function transforms quantized Q8.0 tensor data into a format suitable
 963 * for CANN processing. It extracts quantization values and scales from the
 964 * source data and prepares them in a format expected by CANN operations.
 965 *
 966 * @param tensor Pointer to the tensor information.
 967 * @param src Pointer to the source data in Q8.0 format.
 968 * @param dst Pointer to the destination buffer where transformed data will be
 969 * stored.
 970 */
 971static void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) {
 972    int64_t n_elems     = ggml_nelements(tensor);
 973    int64_t groups      = n_elems / QK8_0;
 974    size_t  quant_bytes = n_elems * sizeof(uint8_t);
 975
 976    uint8_t *  quant_offset = (uint8_t *) dst;
 977    uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
 978
 979    for (int i = 0; i < groups; i++) {
 980        const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0));
 981        *scale_offset            = group->d;
 982        scale_offset++;
 983        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
 984        memcpy(quant_offset, group->qs, group_quant_size);
 985        quant_offset += group_quant_size;
 986    }
 987}
 988
 989/**
 990 * @brief Transform CANN processed data back into quantized Q8.0 format.
 991 *
 992 * This function transforms CANN processed data back into quantized Q8.0 format.
 993 * It reverses the transformation performed by
 994 * ggml_backend_cann_transform_q8_0(), converting the data back into its
 995 * original quantized form.
 996 *
 997 * @param tensor Pointer to the tensor information.
 998 * @param src Pointer to the source buffer containing transformed data.
 999 * @param dst Pointer to the destination buffer where the Q8.0 formatted data
1000 * will be stored.
1001 */
1002static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) {
1003    int64_t n_elems     = ggml_nelements(tensor);
1004    int64_t groups      = n_elems / QK8_0;
1005    size_t  quant_bytes = n_elems * sizeof(uint8_t);
1006
1007    const uint8_t *  quant_offset = (const uint8_t *) src;
1008    const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes);
1009
1010    for (int i = 0; i < groups; i++) {
1011        block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0));
1012        group->d           = *scale_offset;
1013        scale_offset++;
1014        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
1015        memcpy(group->qs, quant_offset, group_quant_size);
1016        quant_offset += group_quant_size;
1017    }
1018}
1019
1020/**
1021 * @brief Transform tensor data based on its type for CANN processing.
1022 *
1023 * This function transforms tensor data based on its quantization type for CANN
1024 * processing. It dispatches the transformation based on the tensor's type to
1025 * specialized functions handling Q4.0 and Q8.0 formats.
1026 *
1027 * @param tensor Pointer to the tensor information.
1028 * @param src Pointer to the source data to be transformed.
1029 * @param dst Pointer to the destination buffer where transformed data will be
1030 * stored.
1031 */
1032static void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) {
1033    switch (tensor->type) {
1034        case GGML_TYPE_Q4_0:
1035            ggml_backend_cann_transform_q4_0(tensor, src, dst);
1036            break;
1037        case GGML_TYPE_Q8_0:
1038            ggml_backend_cann_transform_q8_0(tensor, src, dst);
1039            break;
1040        default:
1041            break;
1042    }
1043}
1044
1045/**
1046 * @brief Transform CANN processed data back into tensor data based on its type.
1047 *
1048 * This function transforms CANN processed data back into tensor data based on
1049 * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
1050 * transformation based on the tensor's type to specialized functions.
1051 *
1052 * @param tensor Pointer to the tensor information.
1053 * @param src Pointer to the source data containing CANN processed data.
1054 * @param dst Pointer to the destination buffer where transformed tensor data
1055 * will be stored.
1056 */
1057static void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) {
1058    switch (tensor->type) {
1059        case GGML_TYPE_Q4_0:
1060            ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
1061            break;
1062        case GGML_TYPE_Q8_0:
1063            ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
1064            break;
1065        default:
1066            break;
1067    }
1068}
1069
1070/**
1071 * @brief Check if transformation is needed for a given tensor type.
1072 *
1073 * This function checks if transformation is needed for a given tensor type
1074 * to prepare data for CANN processing.
1075 *
1076 * @param type The tensor type to check.
1077 * @return true if transformation is needed, false otherwise.
1078 */
1079static bool need_transform(ggml_type type) {
1080    switch (type) {
1081        case GGML_TYPE_Q4_0:
1082        case GGML_TYPE_Q8_0:
1083            return true;
1084        default:
1085            return false;
1086    }
1087}
1088
1089/**
1090 * @brief Initialize a tensor using data from a CANN buffer.
1091 *
1092 * This function initializes a tensor using data from a CANN buffer.
1093 * It handles special cases such as views and quantization.
1094 *
1095 * @param buffer The CANN buffer from which to initialize the tensor.
1096 * @param tensor Pointer to the tensor to be initialized.
1097 */
1098static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1099    if (tensor->view_src != NULL && tensor->view_offs == 0) {
1100        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
1101        return GGML_STATUS_SUCCESS;
1102    }
1103
1104    // TODO: cann backend doesn't support quantized yet. Just leave the code
1105    // here.
1106    if (ggml_is_quantized(tensor->type)) {
1107        // Initialize padding to 0 to avoid possible NaN values
1108        size_t original_size = ggml_nbytes(tensor);
1109        size_t padded_size   = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1110
1111        if (padded_size > original_size && tensor->view_src == nullptr) {
1112            size_t memset_size = padded_size - original_size;
1113            ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size));
1114        }
1115    }
1116    return GGML_STATUS_SUCCESS;
1117}
1118
1119/**
1120 * @brief Workspace for caching NZ buffers per device.
1121 *
1122 * This struct manages a device buffer used in NZ computations. It supports
1123 * allocation, reallocation, and clearing of cached memory. The struct is
1124 * designed to be used with a global array, one per device.
1125 */
1126struct ggml_cann_nz_workspace {
1127    void * ptr;        // Pointer to allocated device buffer
1128    size_t allocated;  // Size of currently allocated buffer in bytes
1129
1130    /**
1131     * @brief Constructor. Initializes the workspace with no allocated memory.
1132     */
1133    ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
1134
1135    /**
1136     * @brief Free cached memory and reset the workspace.
1137     *
1138     * If a buffer has been allocated, this function releases it using
1139     * aclrtFree and resets internal state.
1140     */
1141    void clear() {
1142        if (ptr) {
1143            ACL_CHECK(aclrtFree(ptr));
1144            ptr       = nullptr;
1145            allocated = 0;
1146        }
1147    }
1148
1149    /**
1150     * @brief Allocate or reallocate the workspace buffer.
1151     *
1152     * If the requested size is larger than the currently allocated size,
1153     * the old buffer will be freed and a new buffer of the requested size
1154     * will be allocated on the device.
1155     *
1156     * @param new_size Size in bytes to allocate for the workspace.
1157     */
1158    void realloc(size_t new_size) {
1159        if (new_size > allocated) {
1160            clear();
1161            ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162            allocated = new_size;
1163        }
1164    }
1165
1166    /**
1167     * @brief Get the device buffer pointer.
1168     *
1169     * @return Pointer to the allocated buffer, or nullptr if not allocated.
1170     */
1171    void * get() const { return ptr; }
1172};
1173
1174/**
1175 * @brief Global array of NZ workspaces, one per device.
1176 */
1177static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1178
1179/**
1180 * @brief Convert tensor weights to NZ format using Ascend CANN API.
1181 *
1182 * This function creates a transposed tensor descriptor and performs the
1183 * TransMatmulWeight operation. Converting tensor formats can significantly
1184 * improve performance on certain hardware.
1185 *
1186 * @param tensor Pointer to the input ggml_tensor containing the weights.
1187 * @param offset Byte offset within the tensor data buffer where weights start.
1188 * @param device device id.
1189 *
1190 * @note The workspace buffer used in this function is managed globally and reused
1191 *       across calls. This reduces overhead from repeated memory allocation and deallocation.
1192 */
1193static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
1194    acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
1195    uint64_t       workspaceSize    = 0;
1196    aclOpExecutor * executor;
1197
1198    // TransMatmulWeight
1199    ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
1200    // Avoid frequent malloc/free of the workspace.
1201    g_nz_workspaces[device].realloc(workspaceSize);
1202
1203    void * g_nz_workspace = g_nz_workspaces[device].get();
1204
1205    ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
1206}
1207
1208// TODO: need handle tensor which has paddings.
1209/**
1210 * @brief Set tensor data in a CANN buffer.
1211 *
1212 * This function sets tensor data in a CANN buffer, handling transformations
1213 * if needed based on the tensor's type.
1214 *
1215 * @param buffer The CANN buffer where the tensor data will be set.
1216 * @param tensor Pointer to the tensor whose data will be set.
1217 * @param data Pointer to the source data to be copied into the tensor.
1218 * @param offset Offset in the source data from where to start copying.
1219 * @param size Size of the data to be copied, in bytes.
1220 */
1221static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
1222                                                ggml_tensor *         tensor,
1223                                                const void *          data,
1224                                                size_t                offset,
1225                                                size_t                size) {
1226    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1227
1228    ggml_cann_set_device(ctx->device);
1229    // TODO: refer to cann(#6017), it use thread's default stream.
1230    // For acl, synchronous functions use this default stream.
1231    // Why aclrtSynchronizeDevice?
1232
1233    // Only check env once.
1234    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1235    if (!need_transform(tensor->type)) {
1236        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237        if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1238            GGML_ASSERT(tensor->ne[2] == 1);
1239            GGML_ASSERT(tensor->ne[3] == 1);
1240            weight_format_to_nz(tensor, offset, ctx->device);
1241        }
1242    } else {
1243        void * transform_buffer = malloc(size);
1244        ggml_backend_cann_transform(tensor, data, transform_buffer);
1245
1246        ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1247        free(transform_buffer);
1248    }
1249}
1250
1251/**
1252 * @brief Get tensor data from a CANN buffer.
1253 *
1254 * This function retrieves tensor data from a CANN buffer, handling
1255 * transformations if needed based on the tensor's type.
1256 *
1257 * @param buffer The CANN buffer from which to retrieve tensor data.
1258 * @param tensor Pointer to the tensor whose data will be retrieved.
1259 * @param data Pointer to the destination buffer where the tensor data will be
1260 * copied.
1261 * @param offset Offset in the destination buffer where to start copying.
1262 * @param size Size of the data to be copied, in bytes.
1263 */
1264static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
1265                                                const ggml_tensor *   tensor,
1266                                                void *                data,
1267                                                size_t                offset,
1268                                                size_t                size) {
1269    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1270
1271    ggml_cann_set_device(ctx->device);
1272
1273    if (!need_transform(tensor->type)) {
1274        ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1275    } else {
1276        void * transform_buffer = malloc(size);
1277        ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1278        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1279        free(transform_buffer);
1280    }
1281}
1282
1283/**
1284 * @brief Copy tensor data between CANN buffers if possible.
1285 *
1286 * This function copies tensor data between CANN buffers if the source and
1287 * destination buffers are CANN buffers and they meet the necessary conditions
1288 * (same device or devices can access each other).
1289 *
1290 * @param buffer The destination CANN buffer where the tensor data will be
1291 * copied.
1292 * @param src Pointer to the source tensor whose data will be copied.
1293 * @param dst Pointer to the destination tensor where the data will be copied.
1294 * @return true if the copy operation succeeded, false otherwise.
1295 */
1296static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
1297                                                const ggml_tensor *   src,
1298                                                ggml_tensor *         dst) {
1299    if (ggml_backend_buft_is_cann(src->buffer->buft)) {
1300        ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
1301        ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1302
1303        size_t memcpy_size = ggml_nbytes(src);
1304        // Same device.
1305        if (src_ctx->device == dst_ctx->device) {
1306            ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1307                                  ACL_MEMCPY_DEVICE_TO_DEVICE));
1308            return true;
1309        } else {
1310#ifdef ASCEND_310P
1311            // TODO: Support 310p P2P copy
1312            return false;
1313#endif
1314            // Different device but can access by peer.
1315            int32_t canAccessPeer = 0;
1316            ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device));
1317            if (canAccessPeer) {
1318                ggml_cann_set_device(src_ctx->device);
1319                ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
1320                ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1321                                      ACL_MEMCPY_DEVICE_TO_DEVICE));
1322                return true;
1323            }
1324        }
1325    }
1326    return false;
1327}
1328
1329/**
1330 * @brief Clear a CANN buffer by setting all its memory to a specified value.
1331 *
1332 * This function clears a CANN buffer by setting all its memory to a specified
1333 * value.
1334 *
1335 * @param buffer The CANN buffer to be cleared.
1336 * @param value The value to which each byte in the buffer will be set.
1337 */
1338static void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1339    ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1340
1341    ggml_cann_set_device(ctx->device);
1342    ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
1343}
1344
1345/**
1346 * @brief Interface for a CANN buffer in the backend.
1347 *
1348 * This structure defines function pointers to operations that can be performed
1349 * on a CANN buffer within the backend.
1350 */
1351static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
1352    /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
1353    /* .get_base        = */ ggml_backend_cann_buffer_get_base,
1354    /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
1355    /* .memset_tensor   = */ NULL,
1356    /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
1357    /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
1358    /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
1359    /* .clear           = */ ggml_backend_cann_buffer_clear,
1360    /* .reset           = */ NULL,
1361};
1362
1363/**
1364 * @brief Allocates a new CANN buffer of the specified type and size.
1365 *
1366 * This function allocates a new CANN buffer on the specified device with the
1367 * given size.
1368 *
1369 * @param buft Pointer to the buffer type context.
1370 * @param size Size in bytes of the buffer to allocate.
1371 * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1372 */
1373static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1374    ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
1375
1376    ggml_cann_set_device(buft_ctx->device);
1377
1378    const size_t alignment = 128;
1379    size                   = GGML_PAD(size, alignment);
1380    if (size == 0) {
1381        size = alignment;
1382    }
1383    void *   dev_ptr;
1384    aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1385    if (err != ACL_SUCCESS) {
1386        GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__,
1387                       size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg());
1388        return nullptr;
1389    }
1390
1391    ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1392
1393    return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);
1394}
1395
1396/**
1397 * @brief Retrieves the memory alignment requirement for CANN buffers of this
1398 * type.
1399 *
1400 * This function returns the alignment requirement in bytes for memory allocated
1401 * by the CANN buffer type.
1402 *
1403 * @param buft Pointer to the buffer type context (unused in this
1404 * implementation).
1405 * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1406 * buffers).
1407 */
1408static size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1409    return 128;
1410
1411    GGML_UNUSED(buft);
1412}
1413
1414/**
1415 * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1416 *
1417 * Computes the total allocation size needed for storing the tensor's data in a
1418 * CANN buffer, considering any necessary padding or adjustments for quantized
1419 * types.
1420 *
1421 * @param buft Pointer to the buffer type context (unused in this
1422 * implementation).
1423 * @param tensor Pointer to the tensor for which the allocation size is
1424 * calculated.
1425 * @return The total allocation size in bytes required for the tensor in the
1426 * CANN buffer.
1427 */
1428static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
1429                                                           const ggml_tensor *        tensor) {
1430    size_t  size = ggml_nbytes(tensor);
1431    int64_t ne0  = tensor->ne[0];
1432
1433    // Only check env once.
1434    static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1435
1436    // last line must bigger than 32, because every single op deal at
1437    // least 32 bytes.
1438    // TODO: quantized type?
1439    // int64_t line_size = ne0 * ggml_element_size(tensor);
1440    // int64_t line_size_align_32 = (line_size + 31) & ~31;
1441    // size += (line_size_align_32 - line_size);
1442    if (ggml_is_quantized(tensor->type)) {
1443        if (ne0 % MATRIX_ROW_PADDING != 0) {
1444            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1445        }
1446    } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1447        // NZ format weight are not support quantized yet.
1448        // If ND tensor transform to NZ, size may changed.
1449        int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
1450        GGML_ASSERT(tensor->ne[2] == 1);
1451        GGML_ASSERT(tensor->ne[3] == 1);
1452        const aclIntArray * acl_shape = aclCreateIntArray(shape, 2);
1453        size_t              new_size;
1454        ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size));
1455        ACL_CHECK(aclDestroyIntArray(acl_shape));
1456        size = std::max(size, new_size);
1457    }
1458
1459    return size;
1460
1461    GGML_UNUSED(buft);
1462}
1463
1464static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1465    return false;
1466
1467    GGML_UNUSED(buft);
1468}
1469
1470/**
1471 * @brief Interface for managing CANN buffer types in the GGML backend.
1472 *
1473 * Provides function pointers for allocating, querying properties, and managing
1474 * memory for CANN buffer types in the GGML backend.
1475 */
1476static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1477    /* .get_name         = */ ggml_backend_cann_buffer_type_name,
1478    /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,
1479    /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,
1480    /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX
1481    /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,
1482    /* .is_host          = */ ggml_backend_cann_buffer_type_is_host,
1483};
1484
1485/**
1486 * @brief Retrieves the CANN buffer type for a specified device.
1487 *
1488 * This function initializes and returns the buffer type interface associated
1489 * with the given device. It ensures thread-safe access using a mutex.
1490 *
1491 * @param device The device index for which to retrieve the buffer type.
1492 * @return A pointer to the buffer type interface for the specified device, or
1493 * nullptr if the device index is out of range.
1494 */
1495ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) {
1496    static std::mutex           mutex;
1497    std::lock_guard<std::mutex> lock(mutex);
1498
1499    if (device >= ggml_backend_cann_get_device_count()) {
1500        return nullptr;
1501    }
1502
1503    static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1504
1505    static bool ggml_backend_cann_buffer_type_initialized = false;
1506
1507    if (!ggml_backend_cann_buffer_type_initialized) {
1508        for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
1509            ggml_backend_cann_buffer_types[i] = {
1510                /* .iface    = */ ggml_backend_cann_buffer_type_interface,
1511                /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1512                /* .context  = */
1513                new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i) },
1514            };
1515        }
1516        ggml_backend_cann_buffer_type_initialized = true;
1517    }
1518
1519    return &ggml_backend_cann_buffer_types[device];
1520}
1521
1522/**
1523 * @brief Retrieves the name associated with a CANN host buffer type.
1524 *
1525 * This function returns the descriptive name associated with the specified
1526 * CANN host buffer type context.
1527 *
1528 * @param buft Pointer to the host buffer type context.
1529 * @return Const pointer to the C-style string containing the name.
1530 */
1531static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1532    return "CANN_Host";
1533
1534    GGML_UNUSED(buft);
1535}
1536
1537/**
1538 * @brief Retrieves the name associated with a CANN host buffer.
1539 *
1540 * This function returns the descriptive name associated with the specified
1541 * CANN host buffer context.
1542 *
1543 * @param buft Pointer to the host buffer context.
1544 * @return Const pointer to the C-style string containing the name.
1545 */
1546static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1547    return "CANN_Host";
1548
1549    GGML_UNUSED(buffer);
1550}
1551
1552/**
1553 * @brief Free resources associated with a CANN host buffer.
1554 *
1555 * This function frees the resources associated with a CANN host buffer, including
1556 * its context.
1557 *
1558 * @param buffer The CANN host buffer to free.
1559 */
1560static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1561    ACL_CHECK(aclrtFreeHost(buffer->context));
1562}
1563
1564/**
1565 * @brief Allocates a new CANN host buffer of the specified size.
1566 *
1567 * This function allocates a new CANN host buffer with the given size.
1568 * @param size Size in bytes of the host buffer to allocate.
1569 * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1570 */
1571static void * ggml_cann_host_malloc(size_t size) {
1572    if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1573        return nullptr;
1574    }
1575
1576    const size_t alignment = 128;
1577    size                   = GGML_PAD(size, alignment);
1578    if (size == 0) {
1579        size = alignment;
1580    }
1581
1582    void *   hostPtr = nullptr;
1583    aclError err     = aclrtMallocHost((void **) &hostPtr, size);
1584    if (err != ACL_SUCCESS) {
1585        GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0,
1586                      aclGetRecentErrMsg());
1587        return nullptr;
1588    }
1589    return hostPtr;
1590}
1591
1592/**
1593 * @brief Allocates a new CANN host buffer of the specified type and size.
1594 *
1595 * @param buft Pointer to the host buffer type context.
1596 * @param size Size in bytes of the host buffer to allocate.
1597 * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1598 */
1599static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1600                                                                             size_t                     size) {
1601    void * hostPtr = ggml_cann_host_malloc(size);
1602
1603    if (hostPtr == nullptr) {
1604        // fallback to cpu buffer
1605        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1606    }
1607
1608    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1609    buffer->buft                 = buft;
1610    buffer->iface.free_buffer    = ggml_backend_cann_host_buffer_free;
1611
1612    return buffer;
1613}
1614
1615/**
1616 * @brief Interface for managing CANN host buffer types in the GGML backend.
1617 *
1618 * Provides function pointers for allocating, querying properties, and managing
1619 * memory for CANN buffer types in the GGML backend.
1620 */
1621ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1622    static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1623        /* .iface    = */ {
1624                           /* .get_name         = */ ggml_backend_cann_host_buffer_type_name,
1625                           /* .alloc_buffer     = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1626                           /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1627                           /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX
1628            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1629                           /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1630                           },
1631        /* .device   = */
1632        ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1633        /* .context  = */ nullptr,
1634    };
1635
1636    return &ggml_backend_cann_buffer_type_host;
1637}
1638
1639/**
1640 * @brief Computes the forward operation for a given tensor using CANN
1641 * operations.
1642 *
1643 * This function selects the appropriate CANN operation based on the type of
1644 * operation specified in the tensor and performs the computation.
1645 *
1646 * @param ctx The CANN context containing necessary resources and
1647 * configurations.
1648 * @param dst The destination tensor where the result of the computation will be
1649 * stored.
1650 * @return true if the computation was successful; false otherwise.
1651 */
1652static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {
1653    switch (dst->op) {
1654        case GGML_OP_REPEAT:
1655            ggml_cann_repeat(ctx, dst);
1656            break;
1657        case GGML_OP_GET_ROWS:
1658            ggml_cann_get_rows(ctx, dst);
1659            break;
1660        case GGML_OP_SET_ROWS:
1661            ggml_cann_set_rows(ctx, dst);
1662            break;
1663        case GGML_OP_DUP:
1664            ggml_cann_dup(ctx, dst);
1665            break;
1666        case GGML_OP_ADD:
1667        case GGML_OP_ADD1:
1668            ggml_cann_binary_op<aclnn_add>(ctx, dst);
1669            break;
1670        case GGML_OP_SUB:
1671            ggml_cann_binary_op<aclnn_sub>(ctx, dst);
1672            break;
1673        case GGML_OP_ACC:
1674            ggml_cann_acc(ctx, dst);
1675            break;
1676        case GGML_OP_MUL:
1677            ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1678            break;
1679        case GGML_OP_DIV:
1680            ggml_cann_binary_op<aclnn_div>(ctx, dst);
1681            break;
1682        case GGML_OP_UNARY:
1683            switch (ggml_get_unary_op(dst)) {
1684                case GGML_UNARY_OP_ABS:
1685                    GGML_CANN_CALL_OP_UNARY(Abs);
1686                    break;
1687                case GGML_UNARY_OP_NEG:
1688                    GGML_CANN_CALL_OP_UNARY(Neg);
1689                    break;
1690                case GGML_UNARY_OP_GELU:
1691                case GGML_UNARY_OP_GELU_ERF:
1692                    // aclnnGelu internally uses the erf-based approximation.
1693                    GGML_CANN_CALL_OP_UNARY(Gelu);
1694                    break;
1695                case GGML_UNARY_OP_SILU:
1696                    GGML_CANN_CALL_OP_UNARY(Silu);
1697                    break;
1698                case GGML_UNARY_OP_GELU_QUICK:
1699                    {
1700                        auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1701                            GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1702                        };
1703                        ggml_cann_op_unary(lambda, ctx, dst);
1704                    }
1705                    break;
1706                case GGML_UNARY_OP_TANH:
1707                    GGML_CANN_CALL_OP_UNARY(Tanh);
1708                    break;
1709                case GGML_UNARY_OP_RELU:
1710                    GGML_CANN_CALL_OP_UNARY(Relu);
1711                    break;
1712                case GGML_UNARY_OP_SIGMOID:
1713                    GGML_CANN_CALL_OP_UNARY(Sigmoid);
1714                    break;
1715                case GGML_UNARY_OP_HARDSIGMOID:
1716                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
1717                    break;
1718                case GGML_UNARY_OP_HARDSWISH:
1719                    GGML_CANN_CALL_OP_UNARY(Hardswish);
1720                    break;
1721                case GGML_UNARY_OP_EXP:
1722                    GGML_CANN_CALL_OP_UNARY(Exp);
1723                    break;
1724                case GGML_UNARY_OP_ELU:
1725                    ggml_cann_elu(ctx, dst);
1726                    break;
1727                case GGML_UNARY_OP_SGN:
1728                    GGML_CANN_CALL_OP_UNARY(Sign);
1729                    break;
1730                case GGML_UNARY_OP_STEP:
1731                    ggml_cann_step(ctx, dst);
1732                    break;
1733                default:
1734                    return false;
1735            }
1736            break;
1737        case GGML_OP_GLU:
1738            switch (ggml_get_glu_op(dst)) {
1739                case GGML_GLU_OP_REGLU:
1740                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);
1741                    break;
1742                case GGML_GLU_OP_GEGLU:
1743                case GGML_GLU_OP_GEGLU_ERF:
1744                    // aclnnGelu internally uses the erf-based approximation.
1745                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
1746                    break;
1747                case GGML_GLU_OP_SWIGLU:
1748                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1749                    break;
1750                case GGML_GLU_OP_GEGLU_QUICK:
1751                    {
1752                        auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1753                            GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1754                        };
1755                        ggml_cann_op_unary_gated(lambda, ctx, dst);
1756                    }
1757                    break;
1758                default:
1759                    return false;
1760            }
1761            break;
1762        case GGML_OP_NORM:
1763            ggml_cann_norm(ctx, dst);
1764            break;
1765        case GGML_OP_GROUP_NORM:
1766            ggml_cann_group_norm(ctx, dst);
1767            break;
1768        case GGML_OP_L2_NORM:
1769            ggml_cann_l2_norm(ctx, dst);
1770            break;
1771        case GGML_OP_CROSS_ENTROPY_LOSS:
1772            ggml_cann_cross_entropy_loss(ctx, dst);
1773            break;
1774        case GGML_OP_CONCAT:
1775            ggml_cann_concat(ctx, dst);
1776            break;
1777        case GGML_OP_UPSCALE:
1778            ggml_cann_upsample_nearest2d(ctx, dst);
1779            break;
1780        case GGML_OP_PAD:
1781            ggml_cann_pad(ctx, dst);
1782            break;
1783        case GGML_OP_ARANGE:
1784            ggml_cann_arange(ctx, dst);
1785            break;
1786        case GGML_OP_TIMESTEP_EMBEDDING:
1787            ggml_cann_timestep_embedding(ctx, dst);
1788            break;
1789        case GGML_OP_LEAKY_RELU:
1790            ggml_cann_leaky_relu(ctx, dst);
1791            break;
1792        case GGML_OP_RMS_NORM:
1793            ggml_cann_rms_norm(ctx, dst);
1794            break;
1795        case GGML_OP_MUL_MAT:
1796            ggml_cann_mul_mat(ctx, dst);
1797            break;
1798        case GGML_OP_MUL_MAT_ID:
1799            ggml_cann_mul_mat_id(ctx, dst);
1800            break;
1801        case GGML_OP_SCALE:
1802            ggml_cann_scale(ctx, dst);
1803            break;
1804        case GGML_OP_SQR:
1805            GGML_ASSERT(dst->src[1] == nullptr);
1806            dst->src[1] = dst->src[0];
1807            ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1808            break;
1809        case GGML_OP_SQRT:
1810            GGML_CANN_CALL_OP_UNARY(Sqrt);
1811            break;
1812        case GGML_OP_CLAMP:
1813            ggml_cann_clamp(ctx, dst);
1814            break;
1815        case GGML_OP_CPY:
1816            ggml_cann_cpy(ctx, dst);
1817            break;
1818        case GGML_OP_CONT:
1819            ggml_cann_dup(ctx, dst);
1820            break;
1821        case GGML_OP_NONE:
1822        case GGML_OP_RESHAPE:
1823        case GGML_OP_VIEW:
1824        case GGML_OP_PERMUTE:
1825        case GGML_OP_TRANSPOSE:
1826            break;
1827        case GGML_OP_DIAG_MASK_INF:
1828            ggml_cann_diag_mask(ctx, dst, -INFINITY);
1829            break;
1830        case GGML_OP_SOFT_MAX:
1831            ggml_cann_softmax(ctx, dst);
1832            break;
1833        case GGML_OP_ROPE:
1834            ggml_cann_rope(ctx, dst);
1835            break;
1836        case GGML_OP_IM2COL:
1837            ggml_cann_im2col(ctx, dst);
1838            break;
1839        case GGML_OP_POOL_2D:
1840            ggml_cann_pool2d(ctx, dst);
1841            break;
1842        case GGML_OP_SUM:
1843            ggml_cann_sum(ctx, dst);
1844            break;
1845        case GGML_OP_SUM_ROWS:
1846            ggml_cann_sum_rows(ctx, dst);
1847            break;
1848        case GGML_OP_ARGSORT:
1849            ggml_cann_argsort(ctx, dst);
1850            break;
1851        case GGML_OP_ARGMAX:
1852            ggml_cann_argmax(ctx, dst);
1853            break;
1854        case GGML_OP_COS:
1855            ggml_cann_op_unary<aclnn_cos>(ctx, dst);
1856            break;
1857        case GGML_OP_SIN:
1858            ggml_cann_op_unary<aclnn_sin>(ctx, dst);
1859            break;
1860        case GGML_OP_CONV_TRANSPOSE_1D:
1861            ggml_cann_conv_transpose_1d(ctx, dst);
1862            break;
1863        case GGML_OP_LOG:
1864            GGML_CANN_CALL_OP_UNARY(Log);
1865            break;
1866        case GGML_OP_MEAN:
1867            ggml_cann_mean(ctx, dst);
1868            break;
1869        case GGML_OP_PAD_REFLECT_1D:
1870            ggml_cann_pad_reflect_1d(ctx, dst);
1871            break;
1872        case GGML_OP_COUNT_EQUAL:
1873            ggml_cann_count_equal(ctx, dst);
1874            break;
1875        case GGML_OP_FLASH_ATTN_EXT:
1876            ggml_cann_flash_attn_ext(ctx, dst);
1877            break;
1878        case GGML_OP_OUT_PROD:
1879            ggml_cann_out_prod(ctx, dst);
1880            break;
1881        case GGML_OP_GATED_LINEAR_ATTN:
1882            ggml_cann_gated_linear_attn(ctx, dst);
1883            break;
1884        case GGML_OP_SSM_CONV:
1885            ggml_cann_ssm_conv(ctx, dst);
1886            break;
1887        default:
1888            return false;
1889    }
1890
1891    return true;
1892}
1893
1894// backend
1895/**
1896 * @brief Retrieves the name associated with the CANN backend.
1897 *
1898 * This function returns the name assigned to the CANN backend, which is stored
1899 * in the context of the provided backend structure.
1900 *
1901 * @param backend Pointer to the CANN backend structure.
1902 * @return A pointer to a constant string representing the backend name.
1903 */
1904static const char * ggml_backend_cann_name(ggml_backend_t backend) {
1905    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1906
1907    return cann_ctx->name.c_str();
1908}
1909
1910/**
1911 * @brief Frees resources associated with the CANN backend.
1912 *
1913 * This function releases resources associated with the CANN backend context
1914 * and resets the device associated with the backend to its initial state.
1915 *
1916 * @param backend Pointer to the CANN backend structure to be freed.
1917 */
1918static void ggml_backend_cann_free(ggml_backend_t backend) {
1919    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1920    ACL_CHECK(aclrtSynchronizeDevice());
1921    ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1922
1923    delete cann_ctx;
1924    delete backend;
1925}
1926
1927/**
1928 * @brief Sets tensor data asynchronously in the CANN backend.
1929 *
1930 * This function asynchronously sets tensor data in the CANN backend.
1931 *
1932 * @param backend Pointer to the CANN backend structure.
1933 * @param tensor Pointer to the tensor structure to set data for.
1934 * @param data Pointer to the host data to copy to the tensor.
1935 * @param offset Offset in bytes within the host data.
1936 * @param size Size of the data to copy in bytes.
1937 */
1938static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1939                                               ggml_tensor *  tensor,
1940                                               const void *   data,
1941                                               size_t         offset,
1942                                               size_t         size) {
1943    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1944    ggml_backend_buffer_t       buf      = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1945
1946    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
1947    GGML_ASSERT(!ggml_is_quantized(tensor->type));
1948
1949    ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
1950                               cann_ctx->stream()));
1951}
1952
1953/**
1954 * @brief Gets tensor data asynchronously in the CANN backend.
1955 *
1956 * This function asynchronously gets tensor data in the CANN backend.
1957 *
1958 * @param backend Pointer to the CANN backend structure.
1959 * @param tensor Pointer to the tensor structure to get data from.
1960 * @param data Pointer to the host data to copy from the tensor.
1961 * @param offset Offset in bytes within the host data.
1962 * @param size Size of the data to copy in bytes.
1963 */
1964static void ggml_backend_cann_get_tensor_async(ggml_backend_t      backend,
1965                                               const ggml_tensor * tensor,
1966                                               void *              data,
1967                                               size_t              offset,
1968                                               size_t              size) {
1969    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1970    ggml_backend_buffer_t       buf      = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1971
1972    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
1973    GGML_ASSERT(!ggml_is_quantized(tensor->type));
1974
1975    ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
1976                               cann_ctx->stream()));
1977}
1978
1979/**
1980 * @brief Asynchronously copies tensor data between CANN backends.
1981 *
1982 * This function copies tensor data asynchronously between two CANN backends. It
1983 * checks if both tensors reside in CANN buffers and whether the devices support
1984 * peer-to-peer access for direct copying. If not, it returns false.
1985 *
1986 * @param backend_src Pointer to the source CANN backend structure.
1987 * @param backend_dst Pointer to the destination CANN backend structure.
1988 * @param src Pointer to the source tensor to copy data from.
1989 * @param dst Pointer to the destination tensor to copy data to.
1990 * @return true if the copy operation succeeds, false otherwise.
1991 */
1992static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t      backend_src,
1993                                               ggml_backend_t      backend_dst,
1994                                               const ggml_tensor * src,
1995                                               ggml_tensor *       dst) {
1996    GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst));
1997
1998    GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
1999
2000    if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
2001        return false;
2002    }
2003
2004    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
2005    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;
2006
2007    ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context;
2008    ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context;
2009
2010    size_t copy_size = ggml_nbytes(dst);
2011    if (copy_size == 0) {
2012        return true;
2013    }
2014    if (backend_src != backend_dst) {
2015#ifdef ASCEND_310P
2016        // TODO: Support 310p P2P copy
2017        return false;
2018#endif
2019        ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context;
2020        ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context;
2021
2022        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
2023        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
2024
2025        int32_t canAccessPeer = 0;
2026        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device));
2027        if (!canAccessPeer) {
2028            return false;
2029        }
2030
2031        // need open both directions for memcpyasync between devices.
2032        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
2033        ggml_cann_set_device(cann_ctx_src->device);
2034        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
2035
2036        // wait for task_queue empty to keep task order.
2037        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2038                                   cann_ctx_src->stream()));
2039        // record event on src stream after the copy
2040        // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
2041        // if (!cann_ctx_src->copy_event) {
2042        //     ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
2043        // }
2044        // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
2045
2046        // // wait on dst stream for the copy to complete
2047        // ggml_cann_set_device(cann_ctx_dst->device);
2048        // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
2049        ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
2050    } else {
2051        // src and dst are on the same backend
2052        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
2053                                   cann_ctx_dst->stream()));
2054    }
2055
2056    return true;
2057}
2058
2059/**
2060 * @brief Synchronizes a CANN backend.
2061 *
2062 * This function synchronizes the specified CANN backend by waiting for all
2063 * operations in its associated stream to complete.
2064 *
2065 * @param backend Pointer to the CANN backend structure to synchronize.
2066 */
2067static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
2068    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2069    ggml_cann_set_device(cann_ctx->device);
2070    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2071}
2072
2073/**
2074 * @brief Check if CANN backend can fuse the specified operation sequence
2075 *
2076 * This function determines whether an operation sequence starting from the specified node
2077 * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
2078 * memory access overhead and improve computational efficiency.
2079 *
2080 * @param cgraph Pointer to the computation graph
2081 * @param node_idx Index of the starting node in the computation graph
2082 * @param ops Sequence of operation types to check for fusion
2083 * @return true if the operations can be fused
2084 * @return false if the operations cannot be fused
2085 */
2086static bool ggml_cann_can_fuse(const struct ggml_cgraph *          cgraph,
2087                               int                                 node_idx,
2088                               std::initializer_list<enum ggml_op> ops) {
2089    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2090        return false;
2091    }
2092
2093    // CANN backend supports fusing ADD + RMS_NORM operations
2094    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
2095        ggml_tensor * add_node = cgraph->nodes[node_idx];
2096        // TODO: support broadcast for ADD + RMS_NORM
2097        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
2098            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
2099            return false;
2100        }
2101        return true;
2102    }
2103
2104    return false;
2105}
2106
2107/**
2108 * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
2109 *
2110 * If CANN graph execution is enabled and graph capture is required, this function begins
2111 * graph capture, runs the graph, ends capture, and stores the captured graph.
2112 *
2113 * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
2114 *
2115 * @param cann_ctx                     The CANN backend context.
2116 * @param cgraph                       The ggml computation graph.
2117 * @param use_cann_graph               Whether to use CANN graph execution.
2118 * @param cann_graph_capture_required  Whether graph capture is needed due to graph changes.
2119 */
2120static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
2121                                            ggml_cgraph *               cgraph,
2122                                            bool                        use_cann_graph,
2123                                            bool                        cann_graph_capture_required) {
2124#ifdef USE_ACL_GRAPH
2125    if (use_cann_graph && cann_graph_capture_required) {  // Begin CANN graph capture
2126        ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
2127    }
2128#endif  // USE_ACL_GRAPH
2129    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
2130    // With the use of CANN graphs, the execution will be performed by the graph launch.
2131    static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
2132
2133    if (!use_cann_graph || cann_graph_capture_required) {
2134        for (int i = 0; i < cgraph->n_nodes; i++) {
2135            ggml_tensor * node = cgraph->nodes[i];
2136            if (opt_fusion) {
2137                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
2138                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
2139                    i++;
2140                    continue;
2141                }
2142            }
2143
2144            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
2145                node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
2146                continue;
2147            }
2148
2149            if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
2150                continue;
2151            }
2152
2153            bool ok = ggml_cann_compute_forward(*cann_ctx, node);
2154            if (!ok) {
2155                GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
2156            }
2157            GGML_ASSERT(ok);
2158        }
2159    }
2160
2161#ifdef USE_ACL_GRAPH
2162    if (use_cann_graph) {
2163        GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
2164        ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
2165
2166        if (cann_graph_capture_required) {  // End CANN graph capture
2167            ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
2168        }
2169
2170        // Execute CANN graph
2171        ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
2172    }
2173#endif  // USE_ACL_GRAPH
2174}
2175
2176/**
2177 * @brief Computes a computational graph using a CANN backend.
2178 *
2179 * This function computes the operations defined in the computational graph
2180 * using the specified CANN backend.
2181 *
2182 * @param backend Pointer to the CANN backend structure to use for computation.
2183 * @param cgraph Pointer to the computational graph structure containing nodes
2184 *               representing operations to be computed.
2185 * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
2186 *         completes successfully, otherwise an appropriate error status.
2187 */
2188static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
2189    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2190    ggml_cann_set_device(cann_ctx->device);
2191    g_nz_workspaces[cann_ctx->device].clear();
2192
2193    // calculate rope cache for fist layer in current device.
2194    cann_ctx->rope_cache.cached = false;
2195
2196    bool graph_capture_required = false;
2197#ifdef USE_ACL_GRAPH
2198    bool use_cann_graph = true;
2199
2200    static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
2201    if (!prefill_use_graph) {
2202        // Do not use acl_graph for prefill.
2203        for (int i = 0; i < cgraph->n_nodes; i++) {
2204            ggml_tensor * node = cgraph->nodes[i];
2205            // TODO: Optimize here. Currently, we can only
2206            // get seq_len by FA's input.
2207            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
2208                // Q -> src[0], shape: [B, S, N, D]
2209                use_cann_graph = (node->src[0]->ne[1] == 1);
2210                break;
2211            }
2212        }
2213    }
2214
2215    if (!cann_ctx->acl_graph_mode) {
2216        use_cann_graph = false;
2217    }
2218
2219    if (use_cann_graph) {
2220        // If no matching graph is found, the graph needs to be recaptured.
2221        graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
2222        if (graph_capture_required) {
2223            // If no matching graph is found, add a new ACL graph.
2224            ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
2225            cann_ctx->graph_lru_cache.push(new_graph);
2226        }
2227    }
2228#else
2229    bool use_cann_graph = false;
2230#endif  // USE_ACL_GRAPH
2231    evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);
2232
2233    return GGML_STATUS_SUCCESS;
2234}
2235
2236/**
2237 * @brief Checks if the CANN backend supports a specific operation.
2238 *
2239 * This function checks whether the specified operation is supported by the
2240 * CANN backend.
2241 *
2242 * @param backend Pointer to the CANN backend structure to check support for
2243 *                the operation.
2244 * @param op Pointer to the tensor representing the operation to check.
2245 * @return bool Returns true if the operation is supported by the backend,
2246 *              otherwise false.
2247 */
2248static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2249    switch (op->op) {
2250        case GGML_OP_UNARY:
2251            switch (ggml_get_unary_op(op)) {
2252                case GGML_UNARY_OP_ABS:
2253                case GGML_UNARY_OP_NEG:
2254                case GGML_UNARY_OP_GELU:
2255                case GGML_UNARY_OP_SILU:
2256                case GGML_UNARY_OP_RELU:
2257                case GGML_UNARY_OP_SIGMOID:
2258                case GGML_UNARY_OP_HARDSIGMOID:
2259                case GGML_UNARY_OP_HARDSWISH:
2260                case GGML_UNARY_OP_GELU_QUICK:
2261                case GGML_UNARY_OP_TANH:
2262                case GGML_UNARY_OP_EXP:
2263                case GGML_UNARY_OP_ELU:
2264                case GGML_UNARY_OP_SGN:
2265                case GGML_UNARY_OP_STEP:
2266                case GGML_UNARY_OP_GELU_ERF:
2267                    return true;
2268                default:
2269                    return false;
2270            }
2271        case GGML_OP_GLU:
2272            switch (ggml_get_glu_op(op)) {
2273                case GGML_GLU_OP_REGLU:
2274                case GGML_GLU_OP_GEGLU:
2275                case GGML_GLU_OP_SWIGLU:
2276                case GGML_GLU_OP_GEGLU_ERF:
2277                case GGML_GLU_OP_GEGLU_QUICK:
2278                    return true;
2279                default:
2280                    return false;
2281            }
2282            break;
2283        case GGML_OP_MUL_MAT:
2284            {
2285                switch (op->src[0]->type) {
2286                    case GGML_TYPE_F16:
2287                    case GGML_TYPE_F32:
2288                        return true;
2289                    case GGML_TYPE_Q8_0:
2290                    case GGML_TYPE_Q4_0:
2291#ifdef ASCEND_310P
2292                        // Q4 && Q8 per group is not support on 310p device
2293                        return false;
2294#endif
2295                        // only support contiguous for quantized types.
2296                        return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2297                    default:
2298                        return false;
2299                }
2300            }
2301        case GGML_OP_MUL_MAT_ID:
2302            switch (op->src[0]->type) {
2303                case GGML_TYPE_F16:
2304                case GGML_TYPE_F32:
2305                    return true;
2306                case GGML_TYPE_Q8_0:
2307                case GGML_TYPE_Q4_0:
2308#ifdef ASCEND_310P
2309                    // Q4 && Q8 per group is not support on 310p device
2310                    return false;
2311#endif
2312                    // only support contiguous for quantized types.
2313                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
2314                default:
2315                    return false;
2316            }
2317        // embedding
2318        case GGML_OP_GET_ROWS:
2319            {
2320                switch (op->src[0]->type) {
2321                    case GGML_TYPE_F32:
2322                    case GGML_TYPE_F16:
2323                    case GGML_TYPE_Q8_0:
2324                        return true;
2325                    default:
2326                        return false;
2327                }
2328            }
2329            break;
2330        case GGML_OP_SET_ROWS:
2331            {
2332                switch (op->type) {
2333                    case GGML_TYPE_F32:
2334                    case GGML_TYPE_F16:
2335                        return true;
2336                    default:
2337                        return false;
2338                }
2339            }
2340            break;
2341        case GGML_OP_CPY:
2342            {
2343                ggml_tensor * src = op->src[0];
2344                if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
2345                    (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
2346                    // only support F32 and F16.
2347                    return false;
2348                }
2349                return true;
2350            }
2351            break;
2352        case GGML_OP_CONT:
2353            {
2354                // TODO: support GGML_TYPE_BF16
2355                switch (op->src[0]->type) {
2356                    case GGML_TYPE_F32:
2357                    case GGML_TYPE_F16:
2358                        return true;
2359                    default:
2360                        return false;
2361                }
2362            }
2363        case GGML_OP_ROPE:
2364            {
2365                if (op->src[0]->ne[0] > 896) {
2366                    return false;
2367                }
2368#ifdef ASCEND_310P
2369                // TODO: Support rope_dim < ne00(dim)
2370                if (op->src[0]->ne[0] != op->op_params[1]) {
2371                    return false;
2372                }
2373                if (!ggml_is_contiguous(op->src[0])) {
2374                    return false;
2375                }
2376#endif
2377                return true;
2378            }
2379        case GGML_OP_UPSCALE:
2380            {
2381                // aclnnUpsampleNearest2dGetWorkspaceSize not support
2382                // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
2383                if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
2384                    return false;
2385                }
2386                if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
2387                    return false;
2388                }
2389                if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
2390                    return false;
2391                }
2392                return true;
2393            }
2394        case GGML_OP_POOL_2D:
2395            {
2396                const int32_t * opts = (const int32_t *) op->op_params;
2397#ifdef ASCEND_310P
2398                enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
2399                if (opt == GGML_OP_POOL_MAX) {
2400                    return false;
2401                }
2402#endif
2403                const int k0 = opts[1];
2404                const int k1 = opts[2];
2405                const int p0 = opts[5];
2406                const int p1 = opts[6];
2407                // value of paddingH should be at most half of kernelH
2408                // value of paddingW should be at most half of kernelW
2409                return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
2410            }
2411        case GGML_OP_SUM:
2412            return ggml_is_contiguous_rows(op->src[0]);
2413        case GGML_OP_L2_NORM:
2414        case GGML_OP_CROSS_ENTROPY_LOSS:
2415        case GGML_OP_DUP:
2416        case GGML_OP_IM2COL:
2417        case GGML_OP_CONCAT:
2418        case GGML_OP_REPEAT:
2419        case GGML_OP_NONE:
2420        case GGML_OP_RESHAPE:
2421        case GGML_OP_VIEW:
2422        case GGML_OP_PERMUTE:
2423        case GGML_OP_TRANSPOSE:
2424        case GGML_OP_NORM:
2425        case GGML_OP_ADD:
2426        case GGML_OP_ADD1:
2427        case GGML_OP_SUB:
2428        case GGML_OP_MUL:
2429        case GGML_OP_DIV:
2430        case GGML_OP_RMS_NORM:
2431        case GGML_OP_SQR:
2432        case GGML_OP_SQRT:
2433        case GGML_OP_CLAMP:
2434        case GGML_OP_DIAG_MASK_INF:
2435        case GGML_OP_SUM_ROWS:
2436        case GGML_OP_ARGSORT:
2437        case GGML_OP_ACC:
2438        case GGML_OP_GROUP_NORM:
2439            return true;
2440        case GGML_OP_PAD:
2441            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
2442            return ggml_get_op_params_i32(op, 8) == 0;
2443        case GGML_OP_ARANGE:
2444        case GGML_OP_TIMESTEP_EMBEDDING:
2445        case GGML_OP_LEAKY_RELU:
2446        case GGML_OP_ARGMAX:
2447        case GGML_OP_COS:
2448        case GGML_OP_SIN:
2449        case GGML_OP_LOG:
2450        case GGML_OP_MEAN:
2451        case GGML_OP_PAD_REFLECT_1D:
2452        case GGML_OP_COUNT_EQUAL:
2453        case GGML_OP_GATED_LINEAR_ATTN:
2454            return true;
2455        case GGML_OP_OUT_PROD:
2456            {
2457#ifdef ASCEND_310P
2458                // Ger is not supported on 310p device
2459                return false;
2460#endif
2461                switch (op->src[0]->type) {
2462                    case GGML_TYPE_F16:
2463                    case GGML_TYPE_F32:
2464                        return true;
2465                    default:
2466                        return false;
2467                }
2468            }
2469        case GGML_OP_CONV_TRANSPOSE_1D:
2470            return true;
2471        case GGML_OP_SCALE:
2472            float bias;
2473            memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
2474            return bias == 0.0f;  // TODO: support bias != 0.0f
2475        case GGML_OP_SOFT_MAX:
2476            // TODO: support attention sinks [TAG_ATTN_SINKS]
2477            if (op->src[2]) {
2478                return false;
2479            }
2480            return true;
2481        case GGML_OP_FLASH_ATTN_EXT:
2482            {
2483#ifdef ASCEND_310P
2484                // FA not support on 310p device
2485                return false;
2486#endif
2487                // derived from [ggml-cuda.cu]
2488                if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {
2489                    return false;
2490                }
2491                if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 &&
2492                    op->src[1]->type != GGML_TYPE_BF16) {
2493                    return false;
2494                }
2495                if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) {
2496                    return false;
2497                }
2498                // TODO: support attention sinks [TAG_ATTN_SINKS]
2499                if (op->src[4]) {
2500                    return false;
2501                }
2502                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
2503                    // different head sizes of K and V are not supported yet
2504                    return false;
2505                }
2506                if (op->src[0]->ne[0] % 16 != 0) {
2507                    // TODO: padding to support
2508                    return false;
2509                }
2510                float logitSoftcap = 0.0f;
2511                memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
2512                if (logitSoftcap != 0.0f) {
2513                    return false;
2514                }
2515                return true;
2516            }
2517        case GGML_OP_SSM_CONV:
2518            return true;
2519        default:
2520            return false;
2521    }
2522
2523    GGML_UNUSED(dev);
2524}
2525
2526/**
2527 * @brief Records an event on the CANN backend stream.
2528 *
2529 * This function records the given event on the ACL runtime stream associated
2530 * with the backend context.
2531 *
2532 * @param event Pointer to the event structure to be recorded.
2533 */
2534static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
2535    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2536    ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream()));
2537}
2538
2539/**
2540 * @brief Waits for a recorded event to complete on the CANN backend stream.
2541 *
2542 * This function makes the given backend wait for the event to complete on its
2543 * ACL runtime stream.
2544 *
2545 * @param backend Pointer to the backend structure.
2546 * @param event Pointer to the event structure that the backend needs to wait
2547 * for.
2548 */
2549static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
2550    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2551    if (ggml_backend_is_cann(backend)) {
2552        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context));
2553    } else {
2554        GGML_ABORT("fatal error");
2555    }
2556}
2557
2558/**
2559 * @brief Structure defining the interface for the CANN backend.
2560 *
2561 * This structure contains function pointers for various operations
2562 * supported by the CANN backend, including name retrieval, memory
2563 * management, tensor operations, synchronization, and event handling.
2564 */
2565static const ggml_backend_i ggml_backend_cann_interface = {
2566    /* .get_name                = */ ggml_backend_cann_name,
2567    /* .free                    = */ ggml_backend_cann_free,
2568    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,
2569    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,
2570    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,
2571    /* .synchronize             = */ ggml_backend_cann_synchronize,
2572    /* .graph_plan_create       = */ NULL,
2573    /* .graph_plan_free         = */ NULL,
2574    /* .graph_plan_update       = */ NULL,
2575    /* .graph_plan_compute      = */ NULL,
2576    /* .graph_compute           = */ ggml_backend_cann_graph_compute,
2577    /* .event_record            = */ ggml_backend_cann_event_record,
2578    /* .event_wait              = */ ggml_backend_cann_event_wait,
2579    /* .graph_optimize          = */ NULL,
2580};
2581
2582/**
2583 * @brief Return the hardcoded GUID for the CANN backend.
2584 *
2585 * This function returns a static GUID which uniquely identifies the CANN
2586 * backend.
2587 *
2588 * @return A pointer to the static GUID.
2589 */
2590static ggml_guid_t ggml_backend_cann_guid() {
2591    static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2592                              0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 };
2593    return &guid;
2594}
2595
// backend device
/**
 * @brief Per-device state stored in ggml_backend_device::context for CANN devices.
 *
 * Members carry in-class initializers so a default-initialized instance never
 * holds indeterminate values. The defaults match what value-initialization
 * (`new ggml_backend_cann_device_context()`) already produced.
 */
struct ggml_backend_cann_device_context {
    int         device = 0;                     // CANN device index
    std::string name;                           // registry name, e.g. "CANN0"
    std::string description;                    // SoC name reported by ACL
    int         op_offload_min_batch_size = 0;  // ops with ne[1] below this are not offloaded
};
2603
2604static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
2605    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2606    return ctx->name.c_str();
2607}
2608
2609static const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2610    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2611    return ctx->description.c_str();
2612}
2613
2614static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2615    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2616    ggml_backend_cann_get_device_memory(ctx->device, free, total);
2617}
2618
2619static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
2620    GGML_UNUSED(dev);
2621    return GGML_BACKEND_DEVICE_TYPE_GPU;
2622}
2623
2624static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
2625    props->name        = ggml_backend_cann_device_get_name(dev);
2626    props->description = ggml_backend_cann_device_get_description(dev);
2627    props->type        = ggml_backend_cann_device_get_type(dev);
2628    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
2629
2630    bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
2631
2632    props->caps = {
2633        /* .async                 = */ false,
2634        /* .host_buffer           = */ host_buffer,
2635        /* .buffer_from_host_ptr  = */ false,
2636        /* .events                = */ true,
2637    };
2638}
2639
2640static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
2641    GGML_UNUSED(params);
2642    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2643    return ggml_backend_cann_init(ctx->device);
2644}
2645
2646/**
2647 * @brief Checks if the CANN backend supports a specific backend buffer type.
2648 *
2649 * This function determines whether the CANN backend supports the given backend
2650 * buffer type by comparing the device context of the backend and buffer type.
2651 * It returns true if the devices are same between the backend context and
2652 * buffer type context.
2653 *
2654 * @param backend Pointer to the CANN backend.
2655 * @param buft Pointer to the backend buffer type to check.
2656 * @return bool Returns true if the CANN backend supports the buffer type,
2657 *              otherwise false.
2658 */
2659static bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2660    if (ggml_backend_buft_is_cann(buft)) {
2661        ggml_backend_cann_device_context *      dev_ctx  = (ggml_backend_cann_device_context *) dev->context;
2662        ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
2663        return buft_ctx->device == dev_ctx->device;
2664    }
2665    return false;
2666}
2667
2668static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
2669    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2670    return ggml_backend_cann_buffer_type(ctx->device);
2671}
2672
2673static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
2674    GGML_UNUSED(dev);
2675    return ggml_backend_cann_host_buffer_type();
2676}
2677
2678/**
2679 * @brief Determines if a tensor operation should be offloaded to the CANN
2680 * backend.
2681 *
2682 * This function checks if a given tensor operation should be offloaded to the
2683 * CANN backend based on the operation type and the size of the tensor. It
2684 * returns true if the second dimension (ne[1]) of the tensor is greater than or
2685 * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2686 *
2687 * @param backend Pointer to the CANN backend.
2688 * @param op Pointer to the tensor operation to check.
2689 * @return bool Returns true if the operation should be offloaded, otherwise
2690 * false.
2691 */
2692static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2693    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2694
2695    return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
2696}
2697
2698/**
2699 * @brief Creates a new event for the CANN backend device.
2700 *
2701 * This function initializes a new event for the CANN backend by setting the
2702 * device and creating an ACL runtime event. The created event is then wrapped
2703 * in a ggml_backend_event structure and returned.
2704 *
2705 * @param backend Pointer to the CANN backend.
2706 * @return ggml_backend_event_t Returns a pointer to the new event structure.
2707 */
2708static ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) {
2709    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2710
2711    ggml_cann_set_device(dev_ctx->device);
2712
2713    aclrtEvent event;
2714    ACL_CHECK(aclrtCreateEvent(&event));
2715
2716    return new ggml_backend_event{
2717        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
2718        /* .context = */ event,
2719    };
2720}
2721
2722/**
2723 * @brief Frees a CANN backend event.
2724 *
2725 * This function destroys the ACL runtime event associated with the given CANN
2726 * backend event and then deletes the event structure itself.
2727 *
2728 * @param event Pointer to the event structure to be freed.
2729 */
2730static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2731    ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context));
2732
2733    delete event;
2734    GGML_UNUSED(dev);
2735}
2736
2737/**
2738 * @brief Synchronizes the given event on the CANN backend.
2739 *
2740 * This function waits for the specified event to complete on the ACL runtime.
2741 *
2742 * @param event Pointer to the event structure to be synchronized.
2743 */
2744static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
2745    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context));
2746
2747    GGML_UNUSED(dev);
2748}
2749
/**
 * @brief Device-level interface table for CANN devices.
 *
 * Entries are positional and must follow the field order of ggml_backend_device_i.
 * Creating buffers from host pointers is not supported (NULL entry).
 */
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
    /* .get_name                = */ ggml_backend_cann_device_get_name,
    /* .get_description         = */ ggml_backend_cann_device_get_description,
    /* .get_memory              = */ ggml_backend_cann_device_get_memory,
    /* .get_type                = */ ggml_backend_cann_device_get_type,
    /* .get_props               = */ ggml_backend_cann_device_get_props,
    /* .init_backend            = */ ggml_backend_cann_device_init,  // called for every card
    /* .get_buffer_type         = */ ggml_backend_cann_device_get_buffer_type,
    /* .get_host_buffer_type    = */ ggml_backend_cann_device_get_host_buffer_type,
    /* .buffer_from_host_ptr    = */ NULL,  // not supported for CANN
    /* .supports_op             = */ ggml_backend_cann_supports_op,
    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
    /* .offload_op              = */ ggml_backend_cann_offload_op,
    /* .event_new               = */ ggml_backend_cann_device_event_new,
    /* .event_free              = */ ggml_backend_cann_device_event_free,
    /* .event_synchronize       = */ ggml_backend_cann_device_event_synchronize,
};
2767
2768// backend reg
2769struct ggml_backend_cann_reg_context {
2770    std::vector<ggml_backend_dev_t> devices;
2771};
2772
2773static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2774    GGML_UNUSED(reg);
2775    return GGML_CANN_NAME;
2776}
2777
2778static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2779    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2780    return ctx->devices.size();
2781}
2782
2783static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2784    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2785    GGML_ASSERT(index < ctx->devices.size());
2786    return ctx->devices[index];
2787}
2788
2789static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2790    GGML_UNUSED(reg);
2791    GGML_UNUSED(name);
2792    // reserved for future use
2793    return nullptr;
2794}
2795
// Registry-level interface table for the CANN backend; entries are positional
// and must follow the field order of ggml_backend_reg_i.
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
    /* .get_name          = */ ggml_backend_cann_reg_get_name,
    /* .get_device_count  = */ ggml_backend_cann_reg_get_device_count,
    /* .get_device        = */ ggml_backend_cann_reg_get_device,
    /* .get_proc_address  = */ ggml_backend_cann_reg_get_proc_address,
};
2802
2803// backend registry, called only once for cann backend
2804ggml_backend_reg_t ggml_backend_cann_reg() {
2805    static ggml_backend_reg reg;
2806    static bool             initialized = false;
2807
2808    {
2809        static std::mutex           mutex;
2810        std::lock_guard<std::mutex> lock(mutex);
2811        if (!initialized) {
2812            aclInit(nullptr);
2813            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2814            const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
2815
2816            for (int i = 0; i < ggml_cann_info().device_count; i++) {
2817                ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
2818                dev_ctx->description                       = aclrtGetSocName();
2819                dev_ctx->device                            = i;
2820                dev_ctx->name                              = GGML_CANN_NAME + std::to_string(i);
2821                dev_ctx->op_offload_min_batch_size         = min_batch_size;
2822                ggml_cann_set_device(i);
2823                ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface   = */ ggml_backend_cann_device_interface,
2824                                                                  /* .reg     = */ &reg,
2825                                                                  /* .context = */ dev_ctx };
2826                ctx->devices.push_back(dev);
2827            }
2828
2829            reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION,
2830                                    /* .iface       = */ ggml_backend_cann_reg_interface,
2831                                    /* .context     = */ ctx };
2832        }
2833
2834        initialized = true;
2835    }
2836
2837    return &reg;
2838}
2839
2840ggml_backend_t ggml_backend_cann_init(int32_t device) {
2841    aclInit(nullptr);
2842    if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2843        GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2844        return nullptr;
2845    }
2846
2847    ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device);
2848    if (ctx == nullptr) {
2849        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2850        return nullptr;
2851    }
2852    ggml_cann_set_device(ctx->device);
2853    ggml_backend_t cann_backend =
2854        new ggml_backend{ /* .guid      = */ ggml_backend_cann_guid(),
2855                          /* .interface = */ ggml_backend_cann_interface,
2856                          /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2857                          /* .context   = */ ctx };
2858
2859    return cann_backend;
2860}
2861
2862bool ggml_backend_is_cann(ggml_backend_t backend) {
2863    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2864}
2865
2866int32_t ggml_backend_cann_get_device_count() {
2867    return ggml_cann_info().device_count;
2868}
2869
2870void ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) {
2871    ggml_cann_set_device(device);
2872    const char * soc_name = aclrtGetSocName();
2873    snprintf(description, description_size, "%s", soc_name);
2874}
2875
2876void ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) {
2877    ggml_cann_set_device(device);
2878    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2879}
2880
// Presumably emits the dynamic-loading entry point that exposes
// ggml_backend_cann_reg when this backend is built as a loadable module
// (macro defined outside this file — confirm in ggml-backend-impl.h).
GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)