1/*
2 * Copyright (c) 2023-2026 The ggml authors
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to
6 * deal in the Software without restriction, including without limitation the
7 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8 * sell copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20 * IN THE SOFTWARE.
21 */
22
23#ifndef CANN_COMMON_H
24#define CANN_COMMON_H
25
26#include "../ggml-impl.h"
27#include "../include/ggml-cann.h"
28#include "../include/ggml.h"
29
#include <acl/acl.h>
#include <unistd.h>

#include <algorithm>
#include <atomic>
#include <condition_variable>
#include <cstdio>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <thread>
#include <vector>
46
// NOTE(review): MATRIX_ROW_PADDING is not referenced in this header; presumably it is a
// row padding applied when sizing matrix buffers elsewhere in the backend — confirm at use sites.
#define MATRIX_ROW_PADDING 512
// Number of stream slots held by each ggml_backend_cann_context (size of its `streams` array).
#define GGML_CANN_MAX_STREAMS 8
49
50/**
51 * @brief Handles CANN-related errors by printing an error message and
52 * terminating the program.
53 * @param stmt The statement that caused the error.
54 * @param func The function in which the error occurred.
55 * @param file The file in which the error occurred.
56 * @param line The line number at which the error occurred.
57 * @param msg The error message.
58 */
59[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
60
61/**
62 * @brief Checks the result of a CANN function call and invokes the error
63 * handler if the call fails.
64 * @param stmt The CANN function call to check.
65 * @param success The success code that indicates the call was successful.
66 * @param error_fn The function to call to retrieve the error message.
67 */
68#define ACL_CHECK_GEN(stmt, success, error_fn) \
69 do { \
70 int err_code = (stmt); \
71 if (err_code != (success)) { \
72 ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
73 } \
74 } while (0);
75
76#define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
77
78/**
79 * @brief Contains information about CANN devices.
80 */
81struct ggml_cann_device_info {
82 /**
83 * @brief Number of CANN devices available.
84 */
85 int32_t device_count;
86
87 /**
88 * @brief Information about a single CANN device.
89 */
90 struct cann_device_info {
91 int cc; /**< Compute capability. */
92 size_t smpb; /**< Maximum shared memory per block. */
93 bool vmm; /**< Virtual memory support. */
94 size_t vmm_granularity; /**< Granularity of virtual memory. */
95 size_t total_vram; /**< Total video RAM available on the device. */
96 };
97
98 cann_device_info devices[GGML_CANN_MAX_DEVICES] = {}; /**< Array of CANN device information. */
99};
100
/**
 * @brief Returns a reference to the CANN device information.
 * NOTE(review): defined elsewhere; how/when it is initialized is not visible here.
 */
const ggml_cann_device_info & ggml_cann_info();

/**
 * @brief Selects `device` as the current CANN device.
 * NOTE(review): defined elsewhere — presumably a wrapper around aclrtSetDevice; confirm.
 */
void ggml_cann_set_device(int32_t device);

/**
 * @brief Reads the environment variable `name`.
 * @return Its value (lower-cased, judging by the name), or an empty optional when the
 *         variable is absent — callers in this header supply defaults via value_or().
 *         NOTE(review): confirm both points against the definition.
 */
std::optional<std::string> get_env_as_lowercase(const std::string & name);
/**
 * @brief Interprets `value` as a boolean flag (e.g. the "on" default used for
 *        GGML_CANN_ACL_GRAPH). NOTE(review): accepted spellings are defined elsewhere.
 */
bool parse_bool(const std::string & value);
/**
 * @brief Parses `value` as an integer (used for GGML_CANN_GRAPH_CACHE_CAPACITY).
 *        NOTE(review): behavior on malformed or negative input is defined elsewhere.
 */
int parse_integer(const std::string & value);
108
109/**
110 * @brief Abstract base class for memory pools used by CANN.
111 */
112struct ggml_cann_pool {
113 /**
114 * @brief Virtual destructor for the memory pool.
115 */
116 virtual ~ggml_cann_pool() = default;
117
118 /**
119 * @brief Allocates memory from the pool.
120 *
121 * @param size The size of the memory block to allocate.
122 * @param actual_size Pointer to a variable where the actual allocated size
123 * will be stored.
124 * @return Pointer to the allocated memory block.
125 */
126 virtual void * alloc(size_t size, size_t * actual_size) = 0;
127
128 /**
129 * @brief Frees a previously allocated memory block.
130 *
131 * @param ptr Pointer to the memory block to free.
132 * @param size Size of the memory block to free.
133 * @note Note that all CANN opertors are running async. Make sure memory is
134 * still avaiable before this operator finished.
135 */
136 virtual void free(void * ptr, size_t size) = 0;
137};
138
139/**
140 * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
141 */
142struct ggml_cann_pool_alloc {
143 ggml_cann_pool * pool = nullptr; /**< Pointer to the memory pool. */
144 void * ptr = nullptr; /**< Pointer to the allocated memory block. */
145 size_t actual_size = 0; /**< Actual size of the allocated memory block. */
146
147 /**
148 * @brief Default constructor.
149 */
150 ggml_cann_pool_alloc() = default;
151
152 /**
153 * @brief Constructor that initializes the memory pool.
154 * @param pool Reference to the memory pool.
155 */
156 explicit ggml_cann_pool_alloc(ggml_cann_pool & pool) : pool(&pool) {}
157
158 /**
159 * @brief Constructor that initializes the memory pool and allocates memory.
160 * @param pool Reference to the memory pool.
161 * @param size Size of the memory block to allocate.
162 */
163 ggml_cann_pool_alloc(ggml_cann_pool & pool, size_t size) : pool(&pool) { alloc(size); }
164
165 /**
166 * @brief Destructor that frees the allocated memory block.
167 */
168 ~ggml_cann_pool_alloc() {
169 if (ptr != nullptr) {
170 pool->free(ptr, actual_size);
171 }
172 }
173
174 /**
175 * @brief Allocates memory from the pool.
176 * @param size Size of the memory block to allocate.
177 * @return Pointer to the allocated memory block.
178 */
179 void * alloc(size_t size) {
180 GGML_ASSERT(pool != nullptr);
181 GGML_ASSERT(ptr == nullptr);
182 ptr = pool->alloc(size, &this->actual_size);
183 return ptr;
184 }
185
186 /**
187 * @brief Allocates memory from a specific memory pool.
188 * @param pool Reference to the memory pool.
189 * @param size Size of the memory block to allocate.
190 * @return Pointer to the allocated memory block.
191 */
192 void * alloc(ggml_cann_pool & pool, size_t size) {
193 this->pool = &pool;
194 return alloc(size);
195 }
196
197 /**
198 * @brief Gets the pointer to the allocated memory block.
199 * @return Pointer to the allocated memory block.
200 */
201 void * get() { return ptr; }
202
203 // Deleted copy constructor
204 ggml_cann_pool_alloc(const ggml_cann_pool_alloc &) = delete;
205
206 // Deleted move constructor
207 ggml_cann_pool_alloc(ggml_cann_pool_alloc &&) = delete;
208
209 // Deleted copy assignment operator
210 ggml_cann_pool_alloc & operator=(const ggml_cann_pool_alloc &) = delete;
211
212 // Deleted move assignment operator
213 ggml_cann_pool_alloc & operator=(ggml_cann_pool_alloc &&) = delete;
214};
215
216#ifdef USE_ACL_GRAPH
217struct ggml_graph_node_properties {
218 // dst tensor
219 void * node_address;
220 int64_t ne[GGML_MAX_DIMS];
221 size_t nb[GGML_MAX_DIMS];
222
223 // src tensor
224 void * src_address[GGML_MAX_SRC];
225 int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
226 size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
227
228 // op
229 ggml_op node_op;
230 int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
231
232 /**
233 * @brief Check if a ggml tensor node matches this property set.
234 *
235 * This function compares all relevant fields (address, op type, shape, source inputs, op params)
236 * to determine whether the current node matches these previously recorded properties.
237 *
238 * @param node The current ggml tensor node.
239 * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
240 */
241 bool has_matching_properties(ggml_tensor * node) {
242 if (node->data != this->node_address && node->op != GGML_OP_VIEW) {
243 return false;
244 }
245
246 if (node->op != this->node_op) {
247 return false;
248 }
249
250 for (int i = 0; i < GGML_MAX_DIMS; i++) {
251 if (node->ne[i] != this->ne[i]) {
252 return false;
253 }
254 if (node->nb[i] != this->nb[i]) {
255 return false;
256 }
257 }
258
259 for (int i = 0; i < GGML_MAX_SRC; i++) {
260 if (node->src[i]) {
261 if (node->src[i]->data != this->src_address[i] && node->op != GGML_OP_VIEW) {
262 return false;
263 }
264
265 for (int d = 0; d < GGML_MAX_DIMS; d++) {
266 if (node->src[i]->ne[d] != this->src_ne[i][d]) {
267 return false;
268 }
269 if (node->src[i]->nb[d] != this->src_nb[i][d]) {
270 return false;
271 }
272 }
273 } else {
274 if (this->src_address[i] != nullptr) {
275 return false;
276 }
277 }
278 }
279
280 if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
281 return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
282 }
283 return true;
284 }
285};
286
287struct ggml_cann_graph {
288 ~ggml_cann_graph() {
289 if (graph != nullptr) {
290 ACL_CHECK(aclmdlRIDestroy(graph));
291 }
292 }
293
294 aclmdlRI graph = nullptr;
295
296 std::vector<ggml_graph_node_properties> ggml_graph_properties;
297
298 /**
299 * @brief Create a new CANN graph from a ggml computation graph.
300 *
301 * This function creates a new ggml_cann_graph object and fills its node properties
302 * (operation type, dimensions, strides, input sources, and operation parameters)
303 * based on the current ggml computation graph.
304 *
305 * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
306 * - node address
307 * - operation type
308 * - shape (ne) and strides (nb)
309 * - source tensor addresses
310 * - operation parameters
311 *
312 * @param cgraph The current ggml computation graph.
313 * @return Pointer to the newly created ggml_cann_graph object.
314 */
315 static ggml_cann_graph * create_from_cgraph(ggml_cgraph * cgraph) {
316 ggml_cann_graph * new_graph = new ggml_cann_graph();
317 new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
318
319 for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
320 ggml_tensor * node = cgraph->nodes[node_idx];
321 auto & prop = new_graph->ggml_graph_properties[node_idx];
322
323 prop.node_address = node->data;
324 prop.node_op = node->op;
325
326 std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
327 std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
328
329 for (int src = 0; src < GGML_MAX_SRC; ++src) {
330 if (node->src[src]) {
331 prop.src_address[src] = node->src[src]->data;
332 std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
333 std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
334 } else {
335 prop.src_address[src] = nullptr;
336 std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
337 std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
338 }
339 }
340
341 memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
342 }
343
344 return new_graph;
345 }
346
347 /**
348 * @brief Check whether this CANN graph matches the given ggml computation graph.
349 *
350 * This function compares the number of nodes and each node's properties
351 * (operation type, dimensions, strides, inputs, and operation parameters)
352 * to determine whether this CANN graph matches the given ggml graph.
353 *
354 * @param cgraph The current ggml computation graph.
355 * @return true if this CANN graph matches the ggml graph; false otherwise.
356 */
357 bool matches_cgraph(ggml_cgraph * cgraph) {
358 if (this->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
359 return false;
360 }
361
362 for (int i = 0; i < cgraph->n_nodes; ++i) {
363 if (!this->ggml_graph_properties[i].has_matching_properties(cgraph->nodes[i])) {
364 return false;
365 }
366 }
367
368 return true;
369 }
370};
371
372/**
373 * @brief LRU cache for managing ggml_cann_graph objects.
374 *
375 * This class maintains a list of shared_ptr to ggml_cann_graph objects
376 * and enforces a maximum capacity. It provides methods to push new graphs,
377 * move existing graphs to the front (most recently used), and clear the cache.
378 */
379struct ggml_cann_graph_lru_cache {
380 size_t capacity; /**< Maximum number of graphs in the cache. */
381
382 std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */
383
384 ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); }
385
386 /**
387 * @brief Push a new graph to the front of the cache.
388 * If the cache exceeds capacity, the least recently used graph is deleted.
389 * @param new_node Pointer to the new ggml_cann_graph to cache.
390 * Ownership is transferred to the cache (cache will delete it).
391 */
392 void push(ggml_cann_graph * new_node) {
393 if (cache_list.size() >= capacity) {
394 ggml_cann_graph * old = cache_list.back();
395 cache_list.pop_back();
396 delete old; // free the old graph
397 }
398 cache_list.push_front(new_node);
399 }
400
401 /**
402 * @brief Clear all graphs from the cache (also frees memory).
403 */
404 void clear() {
405 for (auto ptr : cache_list) {
406 delete ptr;
407 }
408 cache_list.clear();
409 }
410
411 /**
412 * @brief Destructor that clears the cache and frees all cached graphs.
413 */
414 ~ggml_cann_graph_lru_cache() { clear(); }
415
416 /**
417 * @brief Find a cached CANN graph that matches the given ggml graph and move it to front.
418 *
419 * This function iterates through the cached CANN graphs stored in the LRU cache and
420 * compares them against the given ggml computation graph. If a matching graph is found,
421 * it is promoted to the front of the LRU cache and returned. Otherwise, the function
422 * returns nullptr.
423 *
424 * @param cgraph The current ggml computation graph.
425 * @return true if found; false otherwise.
426 */
427 bool find_and_move_to_front(ggml_cgraph * cgraph) {
428 for (auto & graph_ptr : this->cache_list) {
429 if (graph_ptr->matches_cgraph(cgraph)) {
430 cache_list.remove(graph_ptr);
431 cache_list.push_front(graph_ptr);
432 return true;
433 }
434 }
435 return false;
436 }
437};
438#endif // USE_ACL_GRAPH
439
440struct ggml_cann_rope_cache {
441 ~ggml_cann_rope_cache() {
442 if (theta_scale_cache) {
443 ACL_CHECK(aclrtFree(theta_scale_cache));
444 }
445 if (sin_cache) {
446 ACL_CHECK(aclrtFree(sin_cache));
447 }
448 if (cos_cache) {
449 ACL_CHECK(aclrtFree(cos_cache));
450 }
451 if (position_select_index) {
452 ACL_CHECK(aclrtFree(position_select_index));
453 }
454 if (theta_scale_exp_host) {
455 free(theta_scale_exp_host);
456 }
457 if (position_select_index_host) {
458 free(position_select_index_host);
459 }
460 if (yarn_ramp_cache) {
461 ACL_CHECK(aclrtFree(yarn_ramp_cache));
462 }
463 }
464
465 bool equal(int64_t theta_scale_length,
466 int64_t position_length,
467 float ext_factor,
468 float theta_scale,
469 float freq_scale,
470 float attn_factor,
471 bool is_neox,
472 bool indep_sects,
473 bool mrope_used,
474 bool is_imrope,
475 int sections[4]) {
476 return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
477 this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
478 this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
479 this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
480 this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
481 }
482
483 void set(int64_t theta_scale_length,
484 int64_t position_length,
485 float ext_factor,
486 float theta_scale,
487 float freq_scale,
488 float attn_factor,
489 bool is_neox,
490 bool indep_sects,
491 bool mrope_used,
492 bool is_imrope,
493 int sections[4]) {
494 this->theta_scale_length = theta_scale_length;
495 this->position_length = position_length;
496 this->ext_factor = ext_factor;
497 this->theta_scale = theta_scale;
498 this->freq_scale = freq_scale;
499 this->attn_factor = attn_factor;
500 this->is_neox = is_neox;
501 this->indep_sects = indep_sects;
502 this->mrope_used = mrope_used;
503 this->is_imrope = is_imrope;
504 this->sections[0] = sections[0];
505 this->sections[1] = sections[1];
506 this->sections[2] = sections[2];
507 this->sections[3] = sections[3];
508 }
509
510 // memory cache, prepare before inferencing.
511 void * theta_scale_cache = nullptr;
512 float * theta_scale_exp_host = nullptr;
513 int * position_select_index_host = nullptr;
514 void * position_select_index = nullptr;
515 void * yarn_ramp_cache = nullptr;
516 // sin/cos cache, used only to accelerate first layer on each device
517 void * sin_cache = nullptr;
518 void * cos_cache = nullptr;
519 // Properties to check before reusing the sincos cache
520 int64_t theta_scale_length = 0;
521 int64_t position_length = 0;
522 bool cached = false;
523 float ext_factor = 0.0f;
524 float theta_scale = 0.0f;
525 float freq_scale = 0.0f;
526 float attn_factor = 0.0f;
527 bool is_neox = false;
528 bool indep_sects = false;
529 bool mrope_used = false;
530 int sections[4] = { 0, 0, 0, 0 };
531 bool is_imrope = false;
532};
533
534struct ggml_cann_tensor_cache {
535 ~ggml_cann_tensor_cache() {
536 if (cache != nullptr) {
537 ACL_CHECK(aclrtFree(cache));
538 }
539 }
540
541 void * cache = nullptr;
542 int64_t size = 0;
543};
544
545/**
546 * @brief Context for managing CANN backend operations.
547 */
548struct ggml_backend_cann_context {
549 int32_t device; /**< Device ID. */
550 std::string name; /**< Name of the device. */
551 std::string description; /**< Description of the device. */
552 aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
553#ifdef USE_ACL_GRAPH
554 /// Cached CANN ACL graph used for executing the current ggml computation graph.
555 ggml_cann_graph_lru_cache graph_lru_cache;
556 bool acl_graph_mode = true;
557#endif
558 bool async_mode;
559 // Rope Cache
560 ggml_cann_rope_cache rope_cache;
561 // Constant Pool
562 ggml_cann_tensor_cache rms_norm_one_tensor_cache;
563 ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
564
565 aclrtStream streams[GGML_CANN_MAX_STREAMS] = { nullptr }; /**< Array of streams for the device. */
566
567 /**
568 * @brief Constructor for initializing the context with a given device.
569 * @param device Device ID.
570 */
571 explicit ggml_backend_cann_context(int device) : device(device), name("CANN" + std::to_string(device)) {
572 ggml_cann_set_device(device);
573 description = aclrtGetSocName();
574
575#ifdef USE_ACL_GRAPH
576 acl_graph_mode = parse_bool(get_env_as_lowercase("GGML_CANN_ACL_GRAPH").value_or("on"));
577 GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER",
578 acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
579#endif
580 }
581
582 /**
583 * @brief Destructor for cleaning up resources.
584 */
585 ~ggml_backend_cann_context() {
586 ggml_cann_set_device(device);
587 if (copy_event != nullptr) {
588 ACL_CHECK(aclrtDestroyEvent(copy_event));
589 }
590 for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
591 if (streams[i] != nullptr) {
592 ACL_CHECK(aclrtDestroyStream(streams[i]));
593 }
594 }
595 }
596
597 /**
598 * @brief Get or create a stream for a given index.
599 * @param stream Index of the stream.
600 * @return The stream corresponding to the given index.
601 */
602 aclrtStream stream(int stream) {
603 if (streams[stream] == nullptr) {
604 // If the device is not set here, destroying the stream later may cause a mismatch
605 // between the thread contexts where the stream was created and destroyed.
606 // However, I printed the device_id, thread_id, and stream, and they are all consistent.
607 ACL_CHECK(aclrtSetDevice(device));
608 ACL_CHECK(aclrtCreateStream(&streams[stream]));
609 }
610 return streams[stream];
611 }
612
613 /**
614 * @brief Get or create the default stream (index 0).
615 * @return The default stream.
616 */
617 aclrtStream stream() { return stream(0); }
618
619 // TODO: each stream should have a memory pool.
620 std::unique_ptr<ggml_cann_pool> mem_pool; /**< Memory pool for the device. */
621
622 /**
623 * @brief Create a new memory pool for a given device.
624 * @param device Device ID.
625 * @return A unique pointer to the new memory pool.
626 */
627 static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
628
629 /**
630 * @brief Get or create the memory pool for the context.
631 * @return Reference to the memory pool.
632 */
633 ggml_cann_pool & pool() {
634 if (mem_pool == nullptr) {
635 mem_pool = new_pool_for_device(device);
636 }
637 return *mem_pool;
638 }
639};
640
641#endif // CANN_COMMON_H