1/*
  2 * Copyright (c) 2023-2026 The ggml authors
  3 *
  4 * Permission is hereby granted, free of charge, to any person obtaining a copy
  5 * of this software and associated documentation files (the "Software"), to
  6 * deal in the Software without restriction, including without limitation the
  7 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
  8 * sell copies of the Software, and to permit persons to whom the Software is
  9 * furnished to do so, subject to the following conditions:
 10 *
 11 * The above copyright notice and this permission notice shall be included in
 12 * all copies or substantial portions of the Software.
 13 *
 14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 19 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 20 * IN THE SOFTWARE.
 21 */
 22
 23#ifndef CANN_COMMON_H
 24#define CANN_COMMON_H
 25
 26#include "../ggml-impl.h"
 27#include "../include/ggml-cann.h"
 28#include "../include/ggml.h"
 29
 30#include <acl/acl.h>
 31#include <unistd.h>
 32
#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <thread>
#include <vector>
 46
/** NOTE(review): not referenced in this header; presumably the multiple to which matrix rows are padded - confirm against its users. */
#define MATRIX_ROW_PADDING    512
/** Number of slots in ggml_backend_cann_context::streams, i.e. the upper bound on streams per context. */
#define GGML_CANN_MAX_STREAMS 8
 49
/**
 * @brief Handles CANN-related errors by printing an error message and
 *        terminating the program.
 *
 * The function is declared [[noreturn]]: control never comes back to the
 * caller. It is the failure handler invoked by ACL_CHECK_GEN.
 *
 * @param stmt The statement that caused the error.
 * @param func The function in which the error occurred.
 * @param file The file in which the error occurred.
 * @param line The line number at which the error occurred.
 * @param msg The error message.
 */
[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
 60
 61/**
 62 * @brief Checks the result of a CANN function call and invokes the error
 63 *        handler if the call fails.
 64 * @param stmt The CANN function call to check.
 65 * @param success The success code that indicates the call was successful.
 66 * @param error_fn The function to call to retrieve the error message.
 67 */
 68#define ACL_CHECK_GEN(stmt, success, error_fn)                                \
 69    do {                                                                      \
 70        int err_code = (stmt);                                                \
 71        if (err_code != (success)) {                                          \
 72            ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
 73        }                                                                     \
 74    } while (0);
 75
 76#define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
 77
 78/**
 79 * @brief Contains information about CANN devices.
 80 */
 81struct ggml_cann_device_info {
 82    /**
 83     * @brief Number of CANN devices available.
 84     */
 85    int32_t device_count;
 86
 87    /**
 88     * @brief Information about a single CANN device.
 89     */
 90    struct cann_device_info {
 91        int    cc;              /**< Compute capability.                   */
 92        size_t smpb;            /**< Maximum shared memory per block.      */
 93        bool   vmm;             /**< Virtual memory support.               */
 94        size_t vmm_granularity; /**< Granularity of virtual memory.        */
 95        size_t total_vram;      /**< Total video RAM available on the device. */
 96    };
 97
 98    cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */
 99};
100
/**
 * @brief Returns a read-only reference to the CANN device information.
 * NOTE(review): defined outside this header; how and when the information is
 * gathered cannot be told from here.
 */
const ggml_cann_device_info & ggml_cann_info();

/**
 * @brief Selects the CANN device with the given ID.
 * @param device Device ID; ggml_backend_cann_context calls this from both its
 *               constructor and its destructor with its own device ID.
 * NOTE(review): presumably makes @p device current for the calling thread - confirm.
 */
void    ggml_cann_set_device(int32_t device);

// Environment/configuration helpers (defined outside this header).
// get_env_as_lowercase: yields an optional string for the variable `name`; callers in this
//   header supply a default through value_or(). Presumably empty when the variable is unset and
//   the value is lower-cased - TODO confirm against the definition.
// parse_bool / parse_integer: convert such a string to bool / int; this header feeds them
//   defaults such as "on" and "12". Behavior on malformed input is not visible here.
std::optional<std::string> get_env_as_lowercase(const std::string & name);
bool                       parse_bool(const std::string & value);
int                        parse_integer(const std::string & value);
108
/**
 * @brief Abstract base class for memory pools used by CANN.
 *
 * Concrete pools implement alloc() and free(); users hold the pool through
 * this interface (see ggml_cann_pool_alloc and ggml_backend_cann_context).
 */
struct ggml_cann_pool {
    /**
     * @brief Virtual destructor so a pool can be destroyed through a base pointer.
     */
    virtual ~ggml_cann_pool() = default;

    /**
     * @brief Allocates memory from the pool.
     *
     * @param size         The size of the memory block to allocate.
     * @param actual_size  Pointer to a variable where the actual allocated size
     *                     will be stored.
     * @return             Pointer to the allocated memory block.
     */
    virtual void * alloc(size_t size, size_t * actual_size) = 0;

    /**
     * @brief Frees a previously allocated memory block.
     *
     * @param ptr   Pointer to the memory block to free.
     * @param size  Size of the memory block to free.
     * @note All CANN operators run asynchronously. Make sure the memory stays
     *       available until the operator using it has finished.
     */
    virtual void free(void * ptr, size_t size) = 0;
};
138
139/**
140 * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
141 */
142struct ggml_cann_pool_alloc {
143    ggml_cann_pool * pool        = nullptr; /**< Pointer to the memory pool. */
144    void *           ptr         = nullptr; /**< Pointer to the allocated memory block. */
145    size_t           actual_size = 0;       /**< Actual size of the allocated memory block. */
146
147    /**
148     * @brief Default constructor.
149     */
150    ggml_cann_pool_alloc() = default;
151
152    /**
153     * @brief Constructor that initializes the memory pool.
154     * @param pool Reference to the memory pool.
155     */
156    explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}
157
158    /**
159     * @brief Constructor that initializes the memory pool and allocates memory.
160     * @param pool Reference to the memory pool.
161     * @param size Size of the memory block to allocate.
162     */
163    ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }
164
165    /**
166     * @brief Destructor that frees the allocated memory block.
167     */
168    ~ggml_cann_pool_alloc() {
169        if (ptr != nullptr) {
170            pool->free(ptr, actual_size);
171        }
172    }
173
174    /**
175     * @brief Allocates memory from the pool.
176     * @param size Size of the memory block to allocate.
177     * @return Pointer to the allocated memory block.
178     */
179    void * alloc(size_t size) {
180        GGML_ASSERT(pool != nullptr);
181        GGML_ASSERT(ptr == nullptr);
182        ptr = pool->alloc(size, &this->actual_size);
183        return ptr;
184    }
185
186    /**
187     * @brief Allocates memory from a specific memory pool.
188     * @param pool Reference to the memory pool.
189     * @param size Size of the memory block to allocate.
190     * @return Pointer to the allocated memory block.
191     */
192    void * alloc(ggml_cann_pool & pool, size_t size) {
193        this->pool = &pool;
194        return alloc(size);
195    }
196
197    /**
198     * @brief Gets the pointer to the allocated memory block.
199     * @return Pointer to the allocated memory block.
200     */
201    void * get() { return ptr; }
202
203    // Deleted copy constructor
204    ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;
205
206    // Deleted move constructor
207    ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;
208
209    // Deleted copy assignment operator
210    ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete;
211
212    // Deleted move assignment operator
213    ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
214};
215
216#ifdef USE_ACL_GRAPH
217struct ggml_graph_node_properties {
218    // dst tensor
219    void *  node_address;
220    int64_t ne[GGML_MAX_DIMS];
221    size_t  nb[GGML_MAX_DIMS];
222
223    // src tensor
224    void *  src_address[GGML_MAX_SRC];
225    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
226    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
227
228    // op
229    ggml_op node_op;
230    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
231
232    /**
233     * @brief Check if a ggml tensor node matches this property set.
234     *
235     * This function compares all relevant fields (address, op type, shape, source inputs, op params)
236     * to determine whether the current node matches these previously recorded properties.
237     *
238     * @param node The current ggml tensor node.
239     * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
240     */
241    bool has_matching_properties(ggml_tensor * node) {
242        if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
243            return false;
244        }
245
246        if (node->op != this->node_op) {
247            return false;
248        }
249
250        for (int i = 0; i < GGML_MAX_DIMS; i++) {
251            if (node->ne[i] != this->ne[i]) {
252                return false;
253            }
254            if (node->nb[i] != this->nb[i]) {
255                return false;
256            }
257        }
258
259        for (int i = 0; i < GGML_MAX_SRC; i++) {
260            if (node->src[i]) {
261                if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
262                    return false;
263                }
264
265                for (int d = 0; d < GGML_MAX_DIMS; d++) {
266                    if (node->src[i]->ne[d] != this->src_ne[i][d]) {
267                        return false;
268                    }
269                    if (node->src[i]->nb[d] != this->src_nb[i][d]) {
270                        return false;
271                    }
272                }
273            } else {
274                if (this->src_address[i] != nullptr) {
275                    return false;
276                }
277            }
278        }
279
280        if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
281            return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
282        }
283        return true;
284    }
285};
286
287struct ggml_cann_graph {
288    ~ggml_cann_graph() {
289        if (graph != nullptr) {
290            ACL_CHECK(aclmdlRIDestroy(graph));
291        }
292    }
293
294    aclmdlRI graph = nullptr;
295
296    std::vector<ggml_graph_node_properties> ggml_graph_properties;
297
298    /**
299     * @brief Create a new CANN graph from a ggml computation graph.
300     *
301     * This function creates a new ggml_cann_graph object and fills its node properties
302     * (operation type, dimensions, strides, input sources, and operation parameters)
303     * based on the current ggml computation graph.
304     *
305     * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
306     * - node address
307     * - operation type
308     * - shape (ne) and strides (nb)
309     * - source tensor addresses
310     * - operation parameters
311     *
312     * @param cgraph The current ggml computation graph.
313     * @return Pointer to the newly created ggml_cann_graph object.
314     */
315    static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
316        ggml_cann_graph * new_graph = new ggml_cann_graph();
317        new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
318
319        for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
320            ggml_tensor * node = cgraph->nodes[node_idx];
321            auto &        prop = new_graph->ggml_graph_properties[node_idx];
322
323            prop.node_address = node->data;
324            prop.node_op      = node->op;
325
326            std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
327            std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
328
329            for (int src = 0; src < GGML_MAX_SRC; ++src) {
330                if (node->src[src]) {
331                    prop.src_address[src] = node->src[src]->data;
332                    std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
333                    std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
334                } else {
335                    prop.src_address[src] = nullptr;
336                    std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
337                    std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
338                }
339            }
340
341            memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
342        }
343
344        return new_graph;
345    }
346
347    /**
348     * @brief Check whether this CANN graph matches the given ggml computation graph.
349     *
350     * This function compares the number of nodes and each node's properties
351     * (operation type, dimensions, strides, inputs, and operation parameters)
352     * to determine whether this CANN graph matches the given ggml graph.
353     *
354     * @param cgraph The current ggml computation graph.
355     * @return true if this CANN graph matches the ggml graph; false otherwise.
356     */
357    bool matches_cgraph(ggml_cgraph * cgraph) {
358        if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
359            return false;
360        }
361
362        for (int i = 0; i < cgraph->n_nodes; ++i) {
363            if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
364                return false;
365            }
366        }
367
368        return true;
369    }
370};
371
372/**
373 * @brief LRU cache for managing ggml_cann_graph objects.
374 *
375 * This class maintains a list of shared_ptr to ggml_cann_graph objects
376 * and enforces a maximum capacity. It provides methods to push new graphs,
377 * move existing graphs to the front (most recently used), and clear the cache.
378 */
379struct ggml_cann_graph_lru_cache {
380    size_t capacity;                         /**< Maximum number of graphs in the cache. */
381
382    std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */
383
384    ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
385
386    /**
387     * @brief Push a new graph to the front of the cache.
388     * If the cache exceeds capacity, the least recently used graph is deleted.
389     * @param new_node Pointer to the new ggml_cann_graph to cache.
390     *        Ownership is transferred to the cache (cache will delete it).
391     */
392    void push(ggml_cann_graph * new_node) {
393        if (cache_list.size() >= capacity) {
394            ggml_cann_graph * old = cache_list.back();
395            cache_list.pop_back();
396            delete old;  // free the old graph
397        }
398        cache_list.push_front(new_node);
399    }
400
401    /**
402     * @brief Clear all graphs from the cache (also frees memory).
403     */
404    void clear() {
405        for (auto ptr : cache_list) {
406            delete ptr;
407        }
408        cache_list.clear();
409    }
410
411    /**
412     * @brief Destructor that clears the cache and frees all cached graphs.
413     */
414    ~ggml_cann_graph_lru_cache() { clear(); }
415
416    /**
417     * @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
418     *
419     * This function iterates through the cached CANN graphs stored in the LRU cache and
420     * compares them against the given ggml computation graph. If a matching graph is found,
421     * it is promoted to the front of the LRU cache and returned. Otherwise, the function
422     * returns nullptr.
423     *
424     * @param cgraph The current ggml computation graph.
425     * @return true if found; false otherwise.
426     */
427    bool find_and_move_to_front(ggml_cgraph * cgraph) {
428        for (auto & graph_ptr : this->cache_list) {
429            if (graph_ptr->matches_cgraph(cgraph)) {
430                cache_list.remove(graph_ptr);
431                cache_list.push_front(graph_ptr);
432                return true;
433            }
434        }
435        return false;
436    }
437};
438#endif  // USE_ACL_GRAPH
439
440struct ggml_cann_rope_cache {
441    ~ggml_cann_rope_cache() {
442        if (theta_scale_cache) {
443            ACL_CHECK(aclrtFree(theta_scale_cache));
444        }
445        if (sin_cache) {
446            ACL_CHECK(aclrtFree(sin_cache));
447        }
448        if (cos_cache) {
449            ACL_CHECK(aclrtFree(cos_cache));
450        }
451        if (position_select_index) {
452            ACL_CHECK(aclrtFree(position_select_index));
453        }
454        if (theta_scale_exp_host) {
455            free(theta_scale_exp_host);
456        }
457        if (position_select_index_host) {
458            free(position_select_index_host);
459        }
460        if (yarn_ramp_cache) {
461            ACL_CHECK(aclrtFree(yarn_ramp_cache));
462        }
463    }
464
465    bool equal(int64_t theta_scale_length,
466               int64_t position_length,
467               float   ext_factor,
468               float   theta_scale,
469               float   freq_scale,
470               float   attn_factor,
471               bool    is_neox,
472               bool    indep_sects,
473               bool    mrope_used,
474               bool    is_imrope,
475               int     sections[4]) {
476        return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
477               this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
478               this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
479               this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
480               this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
481    }
482
483    void set(int64_t theta_scale_length,
484             int64_t position_length,
485             float   ext_factor,
486             float   theta_scale,
487             float   freq_scale,
488             float   attn_factor,
489             bool    is_neox,
490             bool    indep_sects,
491             bool    mrope_used,
492             bool    is_imrope,
493             int     sections[4]) {
494        this->theta_scale_length = theta_scale_length;
495        this->position_length    = position_length;
496        this->ext_factor         = ext_factor;
497        this->theta_scale        = theta_scale;
498        this->freq_scale         = freq_scale;
499        this->attn_factor        = attn_factor;
500        this->is_neox            = is_neox;
501        this->indep_sects        = indep_sects;
502        this->mrope_used         = mrope_used;
503        this->is_imrope          = is_imrope;
504        this->sections[0]        = sections[0];
505        this->sections[1]        = sections[1];
506        this->sections[2]        = sections[2];
507        this->sections[3]        = sections[3];
508    }
509
510    // memory cache, prepare before inferencing.
511    void *  theta_scale_cache          = nullptr;
512    float * theta_scale_exp_host       = nullptr;
513    int *   position_select_index_host = nullptr;
514    void *  position_select_index      = nullptr;
515    void *  yarn_ramp_cache            = nullptr;
516    // sin/cos cache, used only to accelerate first layer on each device
517    void *  sin_cache                  = nullptr;
518    void *  cos_cache                  = nullptr;
519    // Properties to check before reusing the sincos cache
520    int64_t theta_scale_length         = 0;
521    int64_t position_length            = 0;
522    bool    cached                     = false;
523    float   ext_factor                 = 0.0f;
524    float   theta_scale                = 0.0f;
525    float   freq_scale                 = 0.0f;
526    float   attn_factor                = 0.0f;
527    bool    is_neox                    = false;
528    bool    indep_sects                = false;
529    bool    mrope_used                 = false;
530    int     sections[4]                = { 0, 0, 0, 0 };
531    bool    is_imrope                  = false;
532};
533
534struct ggml_cann_tensor_cache {
535    ~ggml_cann_tensor_cache() {
536        if (cache != nullptr) {
537            ACL_CHECK(aclrtFree(cache));
538        }
539    }
540
541    void *  cache = nullptr;
542    int64_t size  = 0;
543};
544
545/**
546 * @brief Context for managing CANN backend operations.
547 */
548struct ggml_backend_cann_context {
549    int32_t     device;               /**< Device ID. */
550    std::string name;                 /**< Name of the device. */
551    std::string description;          /**< Description of the device. */
552    aclrtEvent  copy_event = nullptr; /**< Event for managing copy operations. */
553#ifdef USE_ACL_GRAPH
554    /// Cached CANN ACL graph used for executing the current ggml computation graph.
555    ggml_cann_graph_lru_cache graph_lru_cache;
556    bool                      acl_graph_mode = true;
557#endif
558    bool                   async_mode;
559    // Rope Cache
560    ggml_cann_rope_cache   rope_cache;
561    // Constant Pool
562    ggml_cann_tensor_cache rms_norm_one_tensor_cache;
563    ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
564
565    aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
566
567    /**
568     * @brief Constructor for initializing the context with a given device.
569     * @param device Device ID.
570     */
571    explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
572        ggml_cann_set_device(device);
573        description = aclrtGetSocName();
574
575#ifdef USE_ACL_GRAPH
576        acl_graph_mode = parse_bool(get_env_as_lowercase("GGML_CANN_ACL_GRAPH").value_or("on"));
577        GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
578                      acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
579#endif
580    }
581
582    /**
583     * @brief Destructor for cleaning up resources.
584     */
585    ~ggml_backend_cann_context() {
586        ggml_cann_set_device(device);
587        if (copy_event != nullptr) {
588            ACL_CHECK(aclrtDestroyEvent(copy_event));
589        }
590        for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
591            if (streams[i] != nullptr) {
592                ACL_CHECK(aclrtDestroyStream(streams[i]));
593            }
594        }
595    }
596
597    /**
598     * @brief Get or create a stream for a given index.
599     * @param stream Index of the stream.
600     * @return The stream corresponding to the given index.
601     */
602    aclrtStream stream(int stream) {
603        if (streams[stream] == nullptr) {
604            // If the device is not set here, destroying the stream later may cause a mismatch
605            // between the thread contexts where the stream was created and destroyed.
606            // However, I printed the device_id, thread_id, and stream, and they are all consistent.
607            ACL_CHECK(aclrtSetDevice(device));
608            ACL_CHECK(aclrtCreateStream(&streams[stream]));
609        }
610        return streams[stream];
611    }
612
613    /**
614     * @brief Get or create the default stream (index 0).
615     * @return The default stream.
616     */
617    aclrtStream stream() { return stream(0); }
618
619    // TODO: each stream should have a memory pool.
620    std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */
621
622    /**
623     * @brief Create a new memory pool for a given device.
624     * @param device Device ID.
625     * @return A unique pointer to the new memory pool.
626     */
627    static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
628
629    /**
630     * @brief Get or create the memory pool for the context.
631     * @return Reference to the memory pool.
632     */
633    ggml_cann_pool & pool() {
634        if (mem_pool == nullptr) {
635            mem_pool = new_pool_for_device(device);
636        }
637        return *mem_pool;
638    }
639};
640
641#endif  // CANN_COMMON_H