summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp')
-rw-r--r--llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp2881
1 files changed, 2881 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp b/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp
new file mode 100644
index 0000000..3f3de9f
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp
@@ -0,0 +1,2881 @@
1/*
2 * Copyright (c) 2023-2026 The ggml authors
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a copy
5 * of this software and associated documentation files (the "Software"), to
6 * deal in the Software without restriction, including without limitation the
7 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8 * sell copies of the Software, and to permit persons to whom the Software is
9 * furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice shall be included in
12 * all copies or substantial portions of the Software.
13 *
14 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20 * IN THE SOFTWARE.
21 */
22
23#include "ggml-cann.h"
24
25#include "ggml-backend-impl.h"
26#include "ggml-cann/aclnn_ops.h"
27#include "ggml-cann/common.h"
28#include "ggml-impl.h"
29#include "ggml.h"
30
31#include <acl/acl.h>
32#include <aclnnop/aclnn_trans_matmul_weight.h>
33#include <stdarg.h>
34
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <memory>
#include <mutex>
#include <optional>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
43
44#define GGML_COMMON_DECL_C
45
46#include "ggml-common.h"
47
48#define GGML_CANN_NAME "CANN"
49
50/**
51 * @brief Handles CANN errors by printing an error message and aborting.
52 *
53 * @param stmt The statement that caused the error.
54 * @param func The function in which the error occurred.
55 * @param file The file in which the error occurred.
56 * @param line The line number where the error occurred.
57 * @param msg The error message.
58 */
59[[noreturn]] void ggml_cann_error(const char * stmt, const char * func, const char * file, int line, const char * msg) {
60 int32_t id = -1;
61 aclrtGetDevice(&id);
62
63 GGML_LOG_ERROR("CANN error: %s\n", msg);
64 GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line);
65 GGML_LOG_ERROR(" %s\n", stmt);
66 // abort with GGML_ASSERT to get a stack trace
67 GGML_ABORT("CANN error");
68}
69
// Thread-local record of the device most recently activated by
// ggml_cann_set_device() on this thread; -1 means no device has been set yet.
thread_local int g_current_cann_device = -1;
72
73/**
74 * @brief Set the CANN device to be used.
75 *
76 * @param device The target device ID to set.
77 */
78void ggml_cann_set_device(const int32_t device) {
79 // int current_device = -1;
80 // Note: In some CANN versions, if no device has been set yet,
81 // aclrtGetDevice(&current_device) may return 0 by default.
82 // aclrtGetDevice(&current_device);
83
84 // If the current device is already the target one, no need to switch.
85 if (device == g_current_cann_device) {
86 return;
87 }
88
89 // Switch to the new device.
90 ACL_CHECK(aclrtSetDevice(device));
91
92 // Update the global device record.
93 g_current_cann_device = device;
94}
95
/**
 * @brief Get the value of the specified environment variable as lowercase.
 *
 * @param name The name of the environment variable to read.
 * @return The variable's value lower-cased, or std::nullopt when the
 *         variable is not set. An empty value still yields an engaged
 *         optional holding an empty string.
 */
std::optional<std::string> get_env_as_lowercase(const std::string & name) {
    const char * val = std::getenv(name.c_str());
    if (!val) {
        return std::nullopt;
    }
    std::string res = std::string(val);
    // Cast through unsigned char before tolower: passing a negative char
    // (possible for bytes >= 0x80 when char is signed) is undefined behavior.
    std::transform(res.begin(), res.end(), res.begin(),
                   [](unsigned char c) { return (char) ::tolower(c); });
    return res;
}
109
/**
 * @brief Check whether a (lower-cased) string spells a truthy value.
 *
 * Accepted truthy spellings: "on", "1", "yes", "y", "enable", "true".
 *
 * @param value The string to test.
 * @return true when the string is one of the accepted truthy values.
 */
bool parse_bool(const std::string & value) {
    static const std::unordered_set<std::string> truthy = { "on", "1", "yes", "y", "enable", "true" };
    return truthy.count(value) != 0;
}
117
/**
 * @brief Parse a string as an integer, returning 0 if invalid.
 *
 * Attempts to convert @p value to an int via std::stoi. Both malformed
 * input and values outside the int range fall back to 0.
 *
 * @param value The string to parse.
 * @return The parsed integer, or 0 if conversion fails.
 */
int parse_integer(const std::string & value) {
    int result = 0;
    try {
        result = std::stoi(value);
    } catch (...) {
        // invalid_argument or out_of_range: report 0 instead of throwing
        result = 0;
    }
    return result;
}
135
/**
 * @brief Initialize the CANN device information.
 *
 * Queries the number of visible CANN devices and, for each one, probes
 * whether virtual-memory-management (VMM) allocation is usable by asking
 * for the recommended physical-memory allocation granularity. Also records
 * the device's currently available memory.
 *
 * @return A structure containing the device information. On failure to
 *         query the device count the structure is returned zeroed
 *         (device_count == 0).
 */
static ggml_cann_device_info ggml_cann_init() {
    ggml_cann_device_info info = {};

    aclError err = aclrtGetDeviceCount((uint32_t *) &info.device_count);

    if (err != ACL_SUCCESS) {
        GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n", __func__, aclGetRecentErrMsg());
        return info;
    }

    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);

    for (int id = 0; id < info.device_count; ++id) {
        // Describe a pinned huge-page HBM allocation on this device; the
        // property block is used only to query the granularity below.
        aclrtPhysicalMemProp prop = {};
        prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
        prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
        prop.memAttr = ACL_HBM_MEM_HUGE;
        prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id = id;
        prop.reserve = 0;
        err = aclrtMemGetAllocationGranularity(&prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
                                               &info.devices[id].vmm_granularity);
        // A failed granularity query means VMM is unusable on this device.
        info.devices[id].vmm = err == ACL_SUCCESS;

        size_t free, total;
        ggml_backend_cann_get_device_memory(id, &free, &total);
        // NOTE(review): total_vram is set from the *free* amount at init
        // time, not the device total — confirm this is intentional.
        info.devices[id].total_vram = free;
    }

    // TODO: add more device info later.
    return info;
}
176
177/**
178 * @brief Retrieve the CANN device information.
179 *
180 * This function returns a reference to a structure containing the CANN device
181 * information. The device information is initialized once and reused on
182 * subsequent calls.
183 *
184 * @return A reference to the structure containing the device information.
185 */
186const ggml_cann_device_info & ggml_cann_info() {
187 static ggml_cann_device_info info = ggml_cann_init();
188 return info;
189}
190
//#define DEBUG_CANN_MALLOC
/**
 * @brief A pool of CANN buffers(priority segment buffer).
 *
 * This class manages a pool of CANN buffers for a specific device. Freed
 * buffers are kept in a min-heap ordered by size so an allocation reuses
 * the smallest free buffer that fits; idle buffers may be released back to
 * the device unless cleaning is disabled via environment variable.
 */
struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
    /**
     * @brief The maximum reuse margin for a buffer (reuse only if the
     * candidate is at most this much larger than the request).
     */
    static const size_t max_reuse_margin = 1ull << 22;  // 4MB

    /**
     * @brief The minimum size a free buffer must exceed before it is
     * considered for cleaning (releasing back to the device).
     */
    static const size_t min_free_margin = 1ull << 20;  // 1MB

    /**
     * @brief The alignment for buffer allocation.
     */
    static const size_t alignment = 128;

    /**
     * @brief The device ID associated with this buffer pool.
     */
    int device;

    /**
     * @brief Whether to disable clean during buffer allocation
     * (set from GGML_CANN_DISABLE_BUF_POOL_CLEAN).
     */
    bool disable_clean = false;

    /**
     * @brief Structure representing a CANN buffer.
     */
    struct ggml_cann_buffer {
        void *                                ptr  = nullptr;  ///< Pointer to the buffer.
        size_t                                size = 0;        ///< Size of the buffer.
        std::chrono::steady_clock::time_point last_used;       ///< Last used time.

        // Ordering by size; with std::greater the priority queue becomes a
        // min-heap, so the smallest free buffer is popped first.
        bool operator>(const ggml_cann_buffer & other) const { return size > other.size; }
    };

    /**
     * @brief All live buffers owned by the pool (ptr -> size), whether in
     * use or free; free ones are additionally tracked in free_buffers.
     */
    std::unordered_map<void *, size_t> buffer_pool;
    std::priority_queue<ggml_cann_buffer, std::vector<ggml_cann_buffer>, std::greater<>> free_buffers;

    /**
     * @brief Total size of all buffers in the pool.
     */
    size_t pool_size = 0;

    /**
     * @brief Constructor to initialize the buffer pool for a specific device.
     *
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_buf_prio(int device) : device(device) {
        disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
    }

    /**
     * @brief Destructor to free all buffers in the pool.
     */
    ~ggml_cann_pool_buf_prio() {
        ggml_cann_set_device(device);
        for (auto & [b_ptr, b_size] : buffer_pool) {
            aclrtFree(b_ptr);
            pool_size -= b_size;
        }
        buffer_pool.clear();
        GGML_ASSERT(pool_size == 0);
    }

    /**
     * @brief Allocate a buffer of the given size.
     *
     * Drains the free-buffer heap smallest-first: reuses the first buffer
     * that fits within max_reuse_margin, frees buffers idle for more than
     * 100 ms (when cleaning is enabled), and pushes the rest back. Falls
     * back to a fresh device allocation when nothing can be reused.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void * alloc(size_t size, size_t * actual_size) override {
        size = GGML_PAD(size, alignment);
        if (size == 0) {
            size = alignment;
        }

        void * ptr = nullptr;
        auto   now = std::chrono::steady_clock::now();

        std::vector<ggml_cann_buffer> free_buffers_rest;
        free_buffers_rest.reserve(free_buffers.size());
        while (!free_buffers.empty()) {
            auto b = free_buffers.top();
            free_buffers.pop();

            if (b.size >= size) {
                // reuse the buffer if the size is enough
                const size_t margin = b.size - size;
                if (margin <= max_reuse_margin) {
                    *actual_size = b.size;
                    ptr          = b.ptr;
#ifdef DEBUG_CANN_MALLOC
                    GGML_LOG_INFO(
                        "cann pool[%d]: reused    %p, "
                        "pool_size = %5u MB, "
                        "size = %5u MB, "
                        "margin = %5u MB\n",
                        device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
                        (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
                        (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
#endif
                    break;
                }
            }

            // free buffers that are large enough and have been idle > 100 ms
            bool should_clean = !disable_clean && b.size > min_free_margin &&
                                std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
            if (should_clean) {
                // free the buffer if the size is needed to be freed
                ACL_CHECK(aclrtFree(b.ptr));
                pool_size -= b.size;
                buffer_pool.erase(b.ptr);
#ifdef DEBUG_CANN_MALLOC
                GGML_LOG_INFO(
                    "cann pool[%d]: clean     %p, "
                    "pool_size = %5u MB, "
                    "size = %5u MB\n",
                    device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
                    (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
#endif
                continue;
            }
            free_buffers_rest.push_back(b);
        }
        // buffers that were neither reused nor cleaned go back on the heap
        for (ggml_cann_buffer & b : free_buffers_rest) {
            free_buffers.push(std::move(b));
        }

#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO("cann pool[%d] free pool_size = %5u MB\n\n", device,
                      (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
#endif
        if (ptr != nullptr) {
            return ptr;
        }

        // allocate a new buffer if no buffer can be reused
        ggml_cann_set_device(device);
        ACL_CHECK(aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        *actual_size = size;
        pool_size += size;
#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO(
            "cann pool[%d]: allocate  %p, "
            "pool_size = %5u MB, "
            "size = %5u MB\n",
            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
            (uint32_t) (GGML_PAD(size, 1048576) / 1048576));
#endif
        buffer_pool.emplace(ptr, size);
        return ptr;
    }

    /**
     * @brief Free a buffer and return it to the pool.
     *
     * The buffer stays allocated on the device; it is only moved onto the
     * free heap, stamped with the current time for the idle-clean check.
     * Aborts if @p ptr was not allocated by this pool.
     *
     * @param ptr Pointer to the buffer to free.
     * @param size Size of the buffer to free.
     */
    void free(void * ptr, size_t size) override {
        GGML_UNUSED(size);
        auto it = buffer_pool.find(ptr);
        if (it == buffer_pool.end()) {
            GGML_ABORT("cann pool[%d]: buffer %p not found in pool\n", device, ptr);
        }

        auto now = std::chrono::steady_clock::now();
        free_buffers.emplace(ggml_cann_buffer{ ptr, it->second, now });
#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO(
            "cann pool[%d]: return    %p, "
            "pool_size = %5u MB\n",
            device, ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
#endif
    }
};
381
382/**
383 * @brief A pool of CANN buffers(segment buffer).
384 *
385 * This class manages a pool of CANN buffers for a specific device.
386 */
387struct ggml_cann_pool_buf : public ggml_cann_pool {
388 /**
389 * @brief The maximum reuse margin for a buffer.
390 */
391 static const size_t max_reuse_margin = 1ull << 22; // 4MB
392
393 /**
394 * @brief The minimum free margin for a buffer.
395 */
396 static const size_t min_free_margin = 1ull << 20; // 1MB
397
398 /**
399 * @brief The alignment for buffer allocation.
400 */
401 static const size_t alignment = 128;
402
403 /**
404 * @brief The maximum number of buffers in the pool.
405 */
406 static const int MAX_BUFFERS = 256;
407
408 /**
409 * @brief The device ID associated with this buffer pool.
410 */
411 int device;
412
413 /**
414 * @brief Whether to disable clean during buffer allocation.
415 */
416 bool disable_clean = false;
417
418 /**
419 * @brief Structure representing a CANN buffer.
420 */
421 struct ggml_cann_buffer {
422 void * ptr = nullptr; ///< Pointer to the buffer memory.
423 size_t size = 0; ///< Size of the buffer.
424 bool used = false; ///< Whether the buffer is currently in use.
425 std::chrono::steady_clock::time_point last_used; ///< Last used time.
426 };
427
428 /**
429 * @brief Array of CANN buffers in the pool.
430 */
431 ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
432
433 /**
434 * @brief Total size of all buffers in the pool.
435 */
436 size_t pool_size = 0;
437
438 /**
439 * @brief Constructor to initialize the buffer pool for a specific device.
440 *
441 * @param device The device ID to associate with this buffer pool.
442 */
443 explicit ggml_cann_pool_buf(int device) : device(device) {
444 disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
445 }
446
447 /**
448 * @brief Destructor to free all buffers in the pool.
449 */
450 ~ggml_cann_pool_buf() {
451 ggml_cann_set_device(device);
452 for (int i = 0; i < MAX_BUFFERS; ++i) {
453 ggml_cann_buffer & b = buffer_pool[i];
454 if (b.ptr != nullptr) {
455 aclrtFree(b.ptr);
456 pool_size -= b.size;
457 }
458 }
459 GGML_ASSERT(pool_size == 0);
460 }
461
462 /**
463 * @brief Allocate a buffer of the given size.
464 *
465 * @param size The size of the buffer to allocate.
466 * @param actual_size A pointer to a variable to receive the actual size of
467 * the allocated buffer.
468 * @return A pointer to the allocated buffer.
469 */
470 void * alloc(size_t size, size_t * actual_size) override {
471 size = GGML_PAD(size, alignment);
472 if (size == 0) {
473 size = alignment;
474 }
475
476 void * ptr = nullptr;
477 auto now = std::chrono::steady_clock::now();
478
479 int i = 0;
480 for (; i < MAX_BUFFERS; ++i) {
481 ggml_cann_buffer & b = buffer_pool[i];
482 if (b.ptr == nullptr) {
483 break;
484 }
485 if (b.used) {
486 continue;
487 }
488 if (b.size >= size) {
489 // reuse the buffer if the size is enough
490 const size_t margin = b.size - size;
491 if (margin <= max_reuse_margin) {
492 *actual_size = b.size;
493 b.used = true;
494 ptr = b.ptr;
495#ifdef DEBUG_CANN_MALLOC
496 GGML_LOG_INFO(
497 "cann pool[%d]: reused %p, "
498 "pool_size = %5u MB, "
499 "size = %5u MB, "
500 "margin = %5u MB\n",
501 device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
502 (uint32_t) (GGML_PAD(size, 1048576) / 1048576),
503 (uint32_t) (GGML_PAD(margin, 1048576) / 1048576));
504#endif
505 break;
506 }
507 }
508
509 bool should_clean = !disable_clean && b.size > min_free_margin &&
510 std::chrono::duration_cast<std::chrono::milliseconds>(now - b.last_used).count() > 100;
511 if (should_clean) {
512 // free the buffer if the size is needed to be freed
513 ACL_CHECK(aclrtFree(b.ptr));
514 pool_size -= b.size;
515#ifdef DEBUG_CANN_MALLOC
516 GGML_LOG_INFO(
517 "cann pool[%d]: clean %p, "
518 "pool_size = %5u MB, "
519 "size = %5u MB\n",
520 device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
521 (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
522#endif
523 b.ptr = nullptr;
524 }
525 }
526 if (ptr != nullptr) {
527 return ptr;
528 }
529
530 if (i < MAX_BUFFERS) {
531 // allocate a new buffer if no buffer can be reused
532 ggml_cann_buffer & b = buffer_pool[i];
533 ggml_cann_set_device(device);
534 ACL_CHECK(aclrtMalloc(&b.ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
535 pool_size += size;
536 *actual_size = size;
537 b.size = size;
538 b.used = true;
539 if (i >= MAX_BUFFERS - 8) {
540 GGML_LOG_WARN("cann pool[%d]: slots almost full\n", device);
541 }
542#ifdef DEBUG_CANN_MALLOC
543 GGML_LOG_INFO(
544 "cann pool[%d]: allocate %p, "
545 "pool_size = %5u MB, "
546 "size = %5u MB\n",
547 device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576),
548 (uint32_t) (GGML_PAD(b.size, 1048576) / 1048576));
549#endif
550 return b.ptr;
551 }
552
553 GGML_ABORT("cann pool[%d]: slots full\n", device);
554 }
555
556 /**
557 * @brief Free a buffer and return it to the pool.
558 *
559 * @param ptr Pointer to the buffer to free.
560 * @param size Size of the buffer to free.
561 */
562 void free(void * ptr, size_t size) override {
563 GGML_UNUSED(size);
564 for (int i = 0; i < MAX_BUFFERS; ++i) {
565 ggml_cann_buffer & b = buffer_pool[i];
566 if (b.ptr != ptr) {
567 continue;
568 }
569 b.used = false;
570 b.last_used = std::chrono::steady_clock::now();
571#ifdef DEBUG_CANN_MALLOC
572 GGML_LOG_INFO(
573 "cann pool[%d]: return %p, "
574 "pool_size = %5u MB\n",
575 device, b.ptr, (uint32_t) (GGML_PAD(pool_size, 1048576) / 1048576));
576#endif
577 return;
578 }
579 GGML_ABORT("cann pool[%d]: slots full\n", device);
580 }
581};
582
/**
 * @brief A pool of CANN buffers with virtual memory.
 *
 * This class manages a pool of CANN buffers with virtual memory for a specific
 * device: one contiguous virtual address range is reserved up front and
 * physical memory is mapped into it on demand. Frees are stack-like — they
 * must occur in reverse order of allocation.
 */
struct ggml_cann_pool_vmm : public ggml_cann_pool {
    /**
     * @brief The maximum size of the virtual memory pool; set in the
     * constructor from the device memory reported at backend init time.
     */
    size_t max_size;

    /**
     * @brief The device ID associated with this buffer pool.
     */
    int device;

    /**
     * @brief Pointer to the start of the virtual memory pool.
     */
    void * pool_addr = 0;

    /**
     * @brief Amount of virtual memory used in the pool.
     */
    size_t pool_used = 0;

    /**
     * @brief Total size of the virtual memory pool that is backed by
     * mapped physical memory.
     */
    size_t pool_size = 0;

    /**
     * @brief Allocation granularity for the virtual memory pool; physical
     * reservations are rounded up to a multiple of this.
     */
    size_t granularity;

    /**
     * @brief Handles for the physical memory allocated.
     */
    std::vector<aclrtDrvMemHandle> handles;

    /**
     * @brief Offsets for the mapped memory regions.
     */
    std::vector<void *> map_offsets;

    /**
     * @brief Constructor to initialize the buffer pool with virtual memory for
     * a specific device.
     *
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_vmm(int device) : device(device) {
        auto dev = ggml_cann_info().devices[device];
        granularity = dev.vmm_granularity;
        max_size = dev.total_vram;
    }

    /**
     * @brief Destructor to free all buffers in the virtual memory pool.
     * Unmaps every region, frees the physical handles, then releases the
     * reserved virtual address range.
     */
    ~ggml_cann_pool_vmm() {
        if (pool_addr != 0) {
            for (auto & offset : map_offsets) {
                ACL_CHECK(aclrtUnmapMem(offset));
            }
            for (auto & handle : handles) {
                ACL_CHECK(aclrtFreePhysical(handle));
            }
            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
        }
    }

    /**
     * @brief Allocate a buffer of the given size in the virtual memory pool.
     *
     * Bump-allocates from the current high-water mark; when the mapped
     * region is too small, additional physical memory (rounded up to the
     * granularity) is allocated and mapped at the end of the pool.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void * alloc(size_t size, size_t * actual_size) override {
        // round up the allocation size to the alignment to ensure that all
        // allocations are aligned for all data types
        const size_t alignment = 128;
        size = GGML_PAD(size, alignment);
        if (size == 0) {
            size = alignment;
        }

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = size - avail;
            reserve_size = GGML_PAD(reserve_size, granularity);

            GGML_ASSERT(pool_size + reserve_size <= max_size);

            // allocate more physical memory
            aclrtPhysicalMemProp prop = {};
            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
            prop.memAttr = ACL_HBM_MEM_HUGE;
            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device;
            prop.reserve = 0;
            aclrtDrvMemHandle handle;
            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                ACL_CHECK(aclrtReserveMemAddress(&pool_addr, max_size, 0, NULL, 1));
            }

            // map at the end of the pool
            ACL_CHECK(aclrtMapMem((char *) pool_addr + pool_size, reserve_size, 0, handle, 0));

            handles.push_back(handle);
            map_offsets.push_back((char *) pool_addr + pool_size);

            // add to the pool
            pool_size += reserve_size;

#ifdef DEBUG_CANN_MALLOC
            GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n", device,
                          (unsigned long long) (pool_size / 1024 / 1024),
                          (unsigned long long) (reserve_size / 1024 / 1024));
#endif
        }

        GGML_ASSERT(pool_addr != 0);

        void * ptr = (void *) ((char *) pool_addr + pool_used);
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device, (unsigned long long) size,
                      (unsigned long long) ptr);
#endif
        return ptr;
    }

    /**
     * @brief Free a buffer and return it to the virtual memory pool.
     *
     * Only moves the high-water mark back; no memory is unmapped here.
     *
     * @param ptr Pointer to the buffer to free.
     * @param size Size of the buffer to free.
     */
    void free(void * ptr, size_t size) override {
#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device, (unsigned long long) size,
                      (unsigned long long) ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == (void *) ((char *) pool_addr + pool_used));
    }
};
746
747/**
748 * @brief Create a new CANN pool for a specific device.
749 *
750 * Factory method to create a new CANN pool object based on the device type.
751 *
752 * @param device The device ID for which to create the pool.
753 * @return A unique pointer to the created CANN pool.
754 */
755std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
756 std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or("");
757
758 if (mem_pool_type == "prio") {
759 GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
760 return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf_prio(device));
761 }
762
763 if (ggml_cann_info().devices[device].vmm && mem_pool_type != "leg") {
764 GGML_LOG_INFO("%s: device %d use vmm pool\n", __func__, device);
765 return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
766 }
767
768 GGML_LOG_INFO("%s: device %d use buffer pool\n", __func__, device);
769 return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_buf(device));
770}
771
772// cann buffer
773/**
774 * @brief Context for managing a CANN buffer associated with a specific device.
775 *
776 * This structure holds information about a CANN buffer, including the device
777 * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
778 */
779struct ggml_backend_cann_buffer_context {
780 int32_t device; ///< The device ID associated with this buffer context.
781 void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer.
782
783 /**
784 * @brief Constructor to initialize the CANN buffer context.
785 *
786 * @param device The device ID associated with this buffer context.
787 * @param dev_ptr Pointer to the device memory allocated for the buffer.
788 */
789 ggml_backend_cann_buffer_context(int32_t device, void * dev_ptr) : device(device), dev_ptr(dev_ptr) {}
790
791 /**
792 * @brief Destructor to free the device memory allocated for the buffer.
793 */
794 ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
795};
796
// cann buffer type
/**
 * @brief Context attached to a CANN backend buffer type: identifies the
 * device it belongs to and carries the human-readable type name returned
 * by ggml_backend_cann_buffer_type_name().
 */
struct ggml_backend_cann_buffer_type_context {
    int32_t device;   /**< Device identifier associated with the buffer context. */
    std::string name; /**< Name associated with the buffer context. */
};
806
807/**
808 * @brief Retrieves the name associated with a CANN buffer type.
809 *
810 * This function returns the descriptive name associated with the specified
811 * CANN buffer type context.
812 *
813 * @param buft Pointer to the buffer type context.
814 * @return Const pointer to the C-style string containing the name.
815 */
816static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
817 ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
818
819 return buft_ctx->name.c_str();
820}
821
822/**
823 * @brief Checks if the backend buffer type is associated with the CANN backend.
824 *
825 * This function checks whether the provided backend buffer type is associated
826 * with the CANN backend based on the comparison of its name retrieval function
827 * pointer.
828 *
829 * @param buft Pointer to the backend buffer type to check.
830 * @return bool Returns true if the buffer type is associated with the CANN
831 * backend, otherwise false.
832 */
833static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
834 return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
835}
836
837/**
838 * @brief Free resources associated with a CANN buffer.
839 *
840 * This function frees the resources associated with a CANN buffer, including
841 * its context.
842 *
843 * @param buffer The CANN buffer to free.
844 */
845static void ggml_backend_cann_buffer_free_buffer(ggml_backend_buffer_t buffer) {
846 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
847 delete ctx;
848}
849
850/**
851 * @brief Retrieve the base pointer of a CANN buffer.
852 *
853 * This function returns the base pointer of a CANN buffer, which points to the
854 * device memory allocated for the buffer.
855 *
856 * @param buffer The CANN buffer whose base pointer is to be retrieved.
857 * @return A pointer to the base of the device memory allocated for the buffer.
858 */
859static void * ggml_backend_cann_buffer_get_base(ggml_backend_buffer_t buffer) {
860 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
861 return ctx->dev_ptr;
862}
863
/**
 * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
 * processing.
 *
 * Repacks block_q4_0 data into a split layout: all packed 4-bit values
 * first (n_elems / 2 bytes), followed by all 16-bit group scales. The
 * 4-bit values are also re-biased from offset-binary to a signed nibble
 * representation via the final XOR pass.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source data in Q4.0 format.
 * @param dst Pointer to the destination buffer where transformed data will be
 * stored.
 */
static void ggml_backend_cann_transform_q4_0(ggml_tensor * tensor, const void * src, void * dst) {
    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    // two 4-bit values per output byte
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t * quant_offset = (uint8_t *) dst;
    // scales start right after the packed quant section
    uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);

    for (int i = 0; i < groups; i++) {
        const block_q4_0 * group = (const block_q4_0 *) ((const char *) src + i * sizeof(block_q4_0));
        // copy the group's 16-bit scale as a raw value
        *scale_offset = group->d;
        scale_offset++;

        // 0-15: elements stored in the low nibbles of qs; pack two per byte
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] & 0x0F);
            (*quant_offset) |= ((group->qs[j + 1] << 4));
            quant_offset++;
        }

        // 16-31: elements stored in the high nibbles of qs
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] >> 4);
            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
            quant_offset++;
        }
    }

    // put (uint4b_t -8) into int4b_t: XOR 0x88 flips bit 3 of both nibbles
    // in each byte, converting the offset-by-8 unsigned encoding to signed
    for (quant_offset = (uint8_t *) dst; quant_offset < (uint8_t *) dst + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
}
910
/**
 * @brief Transform CANN processed data back into quantized Q4.0 format.
 *
 * Exact inverse of ggml_backend_cann_transform_q4_0(): first undoes the
 * signed-nibble re-bias (XOR 0x88, applied in place to @p src), then
 * scatters the packed nibbles and scales back into block_q4_0 layout.
 * Note that @p src is modified in place by the XOR pass.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source buffer containing transformed data.
 * @param dst Pointer to the destination buffer where the Q4.0 formatted data
 * will be stored.
 */
static void ggml_backend_cann_transform_back_q4_0(const ggml_tensor * tensor, void * src, void * dst) {
    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    // two 4-bit values per source byte
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t * quant_offset = (uint8_t *) src;
    // scales follow the packed quant section
    uint16_t * scale_offset = (uint16_t *) ((char *) src + quant_bytes);

    // undo the signed re-bias applied by the forward transform (in place)
    for (; quant_offset < (uint8_t *) src + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
    quant_offset = (uint8_t *) src;

    for (int i = 0; i < groups; i++) {
        block_q4_0 * group = (block_q4_0 *) ((char *) dst + i * sizeof(block_q4_0));
        group->d = *scale_offset;
        scale_offset++;

        // 0-15: write back into the low nibbles of qs
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] = ((*quant_offset) & 0x0F);
            group->qs[j + 1] = ((*quant_offset) >> 4);
            quant_offset++;
        }

        // 16-31: OR into the high nibbles of qs (low nibbles already set)
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] |= ((*quant_offset) << 4);
            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
            quant_offset++;
        }
    }
}
957
958/**
959 * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
960 * processing.
961 *
962 * This function transforms quantized Q8.0 tensor data into a format suitable
963 * for CANN processing. It extracts quantization values and scales from the
964 * source data and prepares them in a format expected by CANN operations.
965 *
966 * @param tensor Pointer to the tensor information.
967 * @param src Pointer to the source data in Q8.0 format.
968 * @param dst Pointer to the destination buffer where transformed data will be
969 * stored.
970 */
971static void ggml_backend_cann_transform_q8_0(ggml_tensor * tensor, const void * src, void * dst) {
972 int64_t n_elems = ggml_nelements(tensor);
973 int64_t groups = n_elems / QK8_0;
974 size_t quant_bytes = n_elems * sizeof(uint8_t);
975
976 uint8_t * quant_offset = (uint8_t *) dst;
977 uint16_t * scale_offset = (uint16_t *) ((char *) dst + quant_bytes);
978
979 for (int i = 0; i < groups; i++) {
980 const block_q8_0 * group = (const block_q8_0 *) ((const char *) src + i * sizeof(block_q8_0));
981 *scale_offset = group->d;
982 scale_offset++;
983 size_t group_quant_size = QK8_0 * sizeof(uint8_t);
984 memcpy(quant_offset, group->qs, group_quant_size);
985 quant_offset += group_quant_size;
986 }
987}
988
989/**
990 * @brief Transform CANN processed data back into quantized Q8.0 format.
991 *
992 * This function transforms CANN processed data back into quantized Q8.0 format.
993 * It reverses the transformation performed by
994 * ggml_backend_cann_transform_q8_0(), converting the data back into its
995 * original quantized form.
996 *
997 * @param tensor Pointer to the tensor information.
998 * @param src Pointer to the source buffer containing transformed data.
999 * @param dst Pointer to the destination buffer where the Q8.0 formatted data
1000 * will be stored.
1001 */
1002static void ggml_backend_cann_transform_back_q8_0(const ggml_tensor * tensor, const void * src, void * dst) {
1003 int64_t n_elems = ggml_nelements(tensor);
1004 int64_t groups = n_elems / QK8_0;
1005 size_t quant_bytes = n_elems * sizeof(uint8_t);
1006
1007 const uint8_t * quant_offset = (const uint8_t *) src;
1008 const uint16_t * scale_offset = (const uint16_t *) ((const char *) src + quant_bytes);
1009
1010 for (int i = 0; i < groups; i++) {
1011 block_q8_0 * group = (block_q8_0 *) ((char *) dst + i * sizeof(block_q8_0));
1012 group->d = *scale_offset;
1013 scale_offset++;
1014 size_t group_quant_size = QK8_0 * sizeof(uint8_t);
1015 memcpy(group->qs, quant_offset, group_quant_size);
1016 quant_offset += group_quant_size;
1017 }
1018}
1019
1020/**
1021 * @brief Transform tensor data based on its type for CANN processing.
1022 *
1023 * This function transforms tensor data based on its quantization type for CANN
1024 * processing. It dispatches the transformation based on the tensor's type to
1025 * specialized functions handling Q4.0 and Q8.0 formats.
1026 *
1027 * @param tensor Pointer to the tensor information.
1028 * @param src Pointer to the source data to be transformed.
1029 * @param dst Pointer to the destination buffer where transformed data will be
1030 * stored.
1031 */
1032static void ggml_backend_cann_transform(ggml_tensor * tensor, const void * src, void * dst) {
1033 switch (tensor->type) {
1034 case GGML_TYPE_Q4_0:
1035 ggml_backend_cann_transform_q4_0(tensor, src, dst);
1036 break;
1037 case GGML_TYPE_Q8_0:
1038 ggml_backend_cann_transform_q8_0(tensor, src, dst);
1039 break;
1040 default:
1041 break;
1042 }
1043}
1044
1045/**
1046 * @brief Transform CANN processed data back into tensor data based on its type.
1047 *
1048 * This function transforms CANN processed data back into tensor data based on
1049 * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
1050 * transformation based on the tensor's type to specialized functions.
1051 *
1052 * @param tensor Pointer to the tensor information.
1053 * @param src Pointer to the source data containing CANN processed data.
1054 * @param dst Pointer to the destination buffer where transformed tensor data
1055 * will be stored.
1056 */
1057static void ggml_backend_cann_transform_back(const ggml_tensor * tensor, void * src, void * dst) {
1058 switch (tensor->type) {
1059 case GGML_TYPE_Q4_0:
1060 ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
1061 break;
1062 case GGML_TYPE_Q8_0:
1063 ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
1064 break;
1065 default:
1066 break;
1067 }
1068}
1069
1070/**
1071 * @brief Check if transformation is needed for a given tensor type.
1072 *
1073 * This function checks if transformation is needed for a given tensor type
1074 * to prepare data for CANN processing.
1075 *
1076 * @param type The tensor type to check.
1077 * @return true if transformation is needed, false otherwise.
1078 */
1079static bool need_transform(ggml_type type) {
1080 switch (type) {
1081 case GGML_TYPE_Q4_0:
1082 case GGML_TYPE_Q8_0:
1083 return true;
1084 default:
1085 return false;
1086 }
1087}
1088
1089/**
1090 * @brief Initialize a tensor using data from a CANN buffer.
1091 *
1092 * This function initializes a tensor using data from a CANN buffer.
1093 * It handles special cases such as views and quantization.
1094 *
1095 * @param buffer The CANN buffer from which to initialize the tensor.
1096 * @param tensor Pointer to the tensor to be initialized.
1097 */
1098static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
1099 if (tensor->view_src != NULL && tensor->view_offs == 0) {
1100 GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
1101 return GGML_STATUS_SUCCESS;
1102 }
1103
1104 // TODO: cann backend doesn't support quantized yet. Just leave the code
1105 // here.
1106 if (ggml_is_quantized(tensor->type)) {
1107 // Initialize padding to 0 to avoid possible NaN values
1108 size_t original_size = ggml_nbytes(tensor);
1109 size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
1110
1111 if (padded_size > original_size && tensor->view_src == nullptr) {
1112 size_t memset_size = padded_size - original_size;
1113 ACL_CHECK(aclrtMemset((char *) tensor->data + original_size, memset_size, 0, memset_size));
1114 }
1115 }
1116 return GGML_STATUS_SUCCESS;
1117}
1118
1119/**
1120 * @brief Workspace for caching NZ buffers per device.
1121 *
1122 * This struct manages a device buffer used in NZ computations. It supports
1123 * allocation, reallocation, and clearing of cached memory. The struct is
1124 * designed to be used with a global array, one per device.
1125 */
1126struct ggml_cann_nz_workspace {
1127 void * ptr; // Pointer to allocated device buffer
1128 size_t allocated; // Size of currently allocated buffer in bytes
1129
1130 /**
1131 * @brief Constructor. Initializes the workspace with no allocated memory.
1132 */
1133 ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
1134
1135 /**
1136 * @brief Free cached memory and reset the workspace.
1137 *
1138 * If a buffer has been allocated, this function releases it using
1139 * aclrtFree and resets internal state.
1140 */
1141 void clear() {
1142 if (ptr) {
1143 ACL_CHECK(aclrtFree(ptr));
1144 ptr = nullptr;
1145 allocated = 0;
1146 }
1147 }
1148
1149 /**
1150 * @brief Allocate or reallocate the workspace buffer.
1151 *
1152 * If the requested size is larger than the currently allocated size,
1153 * the old buffer will be freed and a new buffer of the requested size
1154 * will be allocated on the device.
1155 *
1156 * @param new_size Size in bytes to allocate for the workspace.
1157 */
1158 void realloc(size_t new_size) {
1159 if (new_size > allocated) {
1160 clear();
1161 ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
1162 allocated = new_size;
1163 }
1164 }
1165
1166 /**
1167 * @brief Get the device buffer pointer.
1168 *
1169 * @return Pointer to the allocated buffer, or nullptr if not allocated.
1170 */
1171 void * get() const { return ptr; }
1172};
1173
1174/**
1175 * @brief Global array of NZ workspaces, one per device.
1176 */
1177static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
1178
1179/**
1180 * @brief Convert tensor weights to NZ format using Ascend CANN API.
1181 *
1182 * This function creates a transposed tensor descriptor and performs the
1183 * TransMatmulWeight operation. Converting tensor formats can significantly
1184 * improve performance on certain hardware.
1185 *
1186 * @param tensor Pointer to the input ggml_tensor containing the weights.
1187 * @param offset Byte offset within the tensor data buffer where weights start.
1188 * @param device device id.
1189 *
1190 * @note The workspace buffer used in this function is managed globally and reused
1191 * across calls. This reduces overhead from repeated memory allocation and deallocation.
1192 */
1193static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) {
1194 acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset);
1195 uint64_t workspaceSize = 0;
1196 aclOpExecutor * executor;
1197
1198 // TransMatmulWeight
1199 ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor));
1200 // Avoid frequent malloc/free of the workspace.
1201 g_nz_workspaces[device].realloc(workspaceSize);
1202
1203 void * g_nz_workspace = g_nz_workspaces[device].get();
1204
1205 ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
1206}
1207
1208// TODO: need handle tensor which has paddings.
1209/**
1210 * @brief Set tensor data in a CANN buffer.
1211 *
1212 * This function sets tensor data in a CANN buffer, handling transformations
1213 * if needed based on the tensor's type.
1214 *
1215 * @param buffer The CANN buffer where the tensor data will be set.
1216 * @param tensor Pointer to the tensor whose data will be set.
1217 * @param data Pointer to the source data to be copied into the tensor.
1218 * @param offset Offset in the source data from where to start copying.
1219 * @param size Size of the data to be copied, in bytes.
1220 */
1221static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
1222 ggml_tensor * tensor,
1223 const void * data,
1224 size_t offset,
1225 size_t size) {
1226 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1227
1228 ggml_cann_set_device(ctx->device);
1229 // TODO: refer to cann(#6017), it use thread's default stream.
1230 // For acl, synchronous functions use this default stream.
1231 // Why aclrtSynchronizeDevice?
1232
1233 // Only check env once.
1234 static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1235 if (!need_transform(tensor->type)) {
1236 ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
1237 if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1238 GGML_ASSERT(tensor->ne[2] == 1);
1239 GGML_ASSERT(tensor->ne[3] == 1);
1240 weight_format_to_nz(tensor, offset, ctx->device);
1241 }
1242 } else {
1243 void * transform_buffer = malloc(size);
1244 ggml_backend_cann_transform(tensor, data, transform_buffer);
1245
1246 ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE));
1247 free(transform_buffer);
1248 }
1249}
1250
1251/**
1252 * @brief Get tensor data from a CANN buffer.
1253 *
1254 * This function retrieves tensor data from a CANN buffer, handling
1255 * transformations if needed based on the tensor's type.
1256 *
1257 * @param buffer The CANN buffer from which to retrieve tensor data.
1258 * @param tensor Pointer to the tensor whose data will be retrieved.
1259 * @param data Pointer to the destination buffer where the tensor data will be
1260 * copied.
1261 * @param offset Offset in the destination buffer where to start copying.
1262 * @param size Size of the data to be copied, in bytes.
1263 */
1264static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
1265 const ggml_tensor * tensor,
1266 void * data,
1267 size_t offset,
1268 size_t size) {
1269 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1270
1271 ggml_cann_set_device(ctx->device);
1272
1273 if (!need_transform(tensor->type)) {
1274 ACL_CHECK(aclrtMemcpy(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1275 } else {
1276 void * transform_buffer = malloc(size);
1277 ACL_CHECK(aclrtMemcpy(transform_buffer, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST));
1278 ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1279 free(transform_buffer);
1280 }
1281}
1282
1283/**
1284 * @brief Copy tensor data between CANN buffers if possible.
1285 *
1286 * This function copies tensor data between CANN buffers if the source and
1287 * destination buffers are CANN buffers and they meet the necessary conditions
1288 * (same device or devices can access each other).
1289 *
1290 * @param buffer The destination CANN buffer where the tensor data will be
1291 * copied.
1292 * @param src Pointer to the source tensor whose data will be copied.
1293 * @param dst Pointer to the destination tensor where the data will be copied.
1294 * @return true if the copy operation succeeded, false otherwise.
1295 */
1296static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
1297 const ggml_tensor * src,
1298 ggml_tensor * dst) {
1299 if (ggml_backend_buft_is_cann(src->buffer->buft)) {
1300 ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
1301 ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1302
1303 size_t memcpy_size = ggml_nbytes(src);
1304 // Same device.
1305 if (src_ctx->device == dst_ctx->device) {
1306 ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1307 ACL_MEMCPY_DEVICE_TO_DEVICE));
1308 return true;
1309 } else {
1310#ifdef ASCEND_310P
1311 // TODO: Support 310p P2P copy
1312 return false;
1313#endif
1314 // Different device but can access by peer.
1315 int32_t canAccessPeer = 0;
1316 ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, dst_ctx->device));
1317 if (canAccessPeer) {
1318 ggml_cann_set_device(src_ctx->device);
1319 ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
1320 ACL_CHECK(aclrtMemcpy((char *) dst->data, memcpy_size, (const char *) src->data, memcpy_size,
1321 ACL_MEMCPY_DEVICE_TO_DEVICE));
1322 return true;
1323 }
1324 }
1325 }
1326 return false;
1327}
1328
1329/**
1330 * @brief Clear a CANN buffer by setting all its memory to a specified value.
1331 *
1332 * This function clears a CANN buffer by setting all its memory to a specified
1333 * value.
1334 *
1335 * @param buffer The CANN buffer to be cleared.
1336 * @param value The value to which each byte in the buffer will be set.
1337 */
1338static void ggml_backend_cann_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
1339 ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context;
1340
1341 ggml_cann_set_device(ctx->device);
1342 ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
1343}
1344
1345/**
1346 * @brief Interface for a CANN buffer in the backend.
1347 *
1348 * This structure defines function pointers to operations that can be performed
1349 * on a CANN buffer within the backend.
1350 */
1351static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
1352 /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
1353 /* .get_base = */ ggml_backend_cann_buffer_get_base,
1354 /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
1355 /* .memset_tensor = */ NULL,
1356 /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
1357 /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
1358 /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
1359 /* .clear = */ ggml_backend_cann_buffer_clear,
1360 /* .reset = */ NULL,
1361};
1362
1363/**
1364 * @brief Allocates a new CANN buffer of the specified type and size.
1365 *
1366 * This function allocates a new CANN buffer on the specified device with the
1367 * given size.
1368 *
1369 * @param buft Pointer to the buffer type context.
1370 * @param size Size in bytes of the buffer to allocate.
1371 * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1372 */
1373static ggml_backend_buffer_t ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1374 ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
1375
1376 ggml_cann_set_device(buft_ctx->device);
1377
1378 const size_t alignment = 128;
1379 size = GGML_PAD(size, alignment);
1380 if (size == 0) {
1381 size = alignment;
1382 }
1383 void * dev_ptr;
1384 aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1385 if (err != ACL_SUCCESS) {
1386 GGML_LOG_ERROR("%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n", __func__,
1387 size / 1024.0 / 1024.0, buft_ctx->device, aclGetRecentErrMsg());
1388 return nullptr;
1389 }
1390
1391 ggml_backend_cann_buffer_context * ctx = new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1392
1393 return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface, ctx, size);
1394}
1395
1396/**
1397 * @brief Retrieves the memory alignment requirement for CANN buffers of this
1398 * type.
1399 *
1400 * This function returns the alignment requirement in bytes for memory allocated
1401 * by the CANN buffer type.
1402 *
1403 * @param buft Pointer to the buffer type context (unused in this
1404 * implementation).
1405 * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1406 * buffers).
1407 */
1408static size_t ggml_backend_cann_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
1409 return 128;
1410
1411 GGML_UNUSED(buft);
1412}
1413
1414/**
1415 * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1416 *
1417 * Computes the total allocation size needed for storing the tensor's data in a
1418 * CANN buffer, considering any necessary padding or adjustments for quantized
1419 * types.
1420 *
1421 * @param buft Pointer to the buffer type context (unused in this
1422 * implementation).
1423 * @param tensor Pointer to the tensor for which the allocation size is
1424 * calculated.
1425 * @return The total allocation size in bytes required for the tensor in the
1426 * CANN buffer.
1427 */
1428static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft,
1429 const ggml_tensor * tensor) {
1430 size_t size = ggml_nbytes(tensor);
1431 int64_t ne0 = tensor->ne[0];
1432
1433 // Only check env once.
1434 static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on"));
1435
1436 // last line must bigger than 32, because every single op deal at
1437 // least 32 bytes.
1438 // TODO: quantized type?
1439 // int64_t line_size = ne0 * ggml_element_size(tensor);
1440 // int64_t line_size_align_32 = (line_size + 31) & ~31;
1441 // size += (line_size_align_32 - line_size);
1442 if (ggml_is_quantized(tensor->type)) {
1443 if (ne0 % MATRIX_ROW_PADDING != 0) {
1444 size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1445 }
1446 } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
1447 // NZ format weight are not support quantized yet.
1448 // If ND tensor transform to NZ, size may changed.
1449 int64_t shape[] = { tensor->ne[1], tensor->ne[0] };
1450 GGML_ASSERT(tensor->ne[2] == 1);
1451 GGML_ASSERT(tensor->ne[3] == 1);
1452 const aclIntArray * acl_shape = aclCreateIntArray(shape, 2);
1453 size_t new_size;
1454 ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(acl_shape, ggml_cann_type_mapping(tensor->type), &new_size));
1455 ACL_CHECK(aclDestroyIntArray(acl_shape));
1456 size = std::max(size, new_size);
1457 }
1458
1459 return size;
1460
1461 GGML_UNUSED(buft);
1462}
1463
1464static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1465 return false;
1466
1467 GGML_UNUSED(buft);
1468}
1469
1470/**
1471 * @brief Interface for managing CANN buffer types in the GGML backend.
1472 *
1473 * Provides function pointers for allocating, querying properties, and managing
1474 * memory for CANN buffer types in the GGML backend.
1475 */
1476static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1477 /* .get_name = */ ggml_backend_cann_buffer_type_name,
1478 /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1479 /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1480 /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1481 /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1482 /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
1483};
1484
1485/**
1486 * @brief Retrieves the CANN buffer type for a specified device.
1487 *
1488 * This function initializes and returns the buffer type interface associated
1489 * with the given device. It ensures thread-safe access using a mutex.
1490 *
1491 * @param device The device index for which to retrieve the buffer type.
1492 * @return A pointer to the buffer type interface for the specified device, or
1493 * nullptr if the device index is out of range.
1494 */
1495ggml_backend_buffer_type_t ggml_backend_cann_buffer_type(int32_t device) {
1496 static std::mutex mutex;
1497 std::lock_guard<std::mutex> lock(mutex);
1498
1499 if (device >= ggml_backend_cann_get_device_count()) {
1500 return nullptr;
1501 }
1502
1503 static ggml_backend_buffer_type ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1504
1505 static bool ggml_backend_cann_buffer_type_initialized = false;
1506
1507 if (!ggml_backend_cann_buffer_type_initialized) {
1508 for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
1509 ggml_backend_cann_buffer_types[i] = {
1510 /* .iface = */ ggml_backend_cann_buffer_type_interface,
1511 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1512 /* .context = */
1513 new ggml_backend_cann_buffer_type_context{ i, "CANN" + std::to_string(i) },
1514 };
1515 }
1516 ggml_backend_cann_buffer_type_initialized = true;
1517 }
1518
1519 return &ggml_backend_cann_buffer_types[device];
1520}
1521
1522/**
1523 * @brief Retrieves the name associated with a CANN host buffer type.
1524 *
1525 * This function returns the descriptive name associated with the specified
1526 * CANN host buffer type context.
1527 *
1528 * @param buft Pointer to the host buffer type context.
1529 * @return Const pointer to the C-style string containing the name.
1530 */
1531static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1532 return "CANN_Host";
1533
1534 GGML_UNUSED(buft);
1535}
1536
1537/**
1538 * @brief Retrieves the name associated with a CANN host buffer.
1539 *
1540 * This function returns the descriptive name associated with the specified
1541 * CANN host buffer context.
1542 *
1543 * @param buft Pointer to the host buffer context.
1544 * @return Const pointer to the C-style string containing the name.
1545 */
1546static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1547 return "CANN_Host";
1548
1549 GGML_UNUSED(buffer);
1550}
1551
1552/**
1553 * @brief Free resources associated with a CANN host buffer.
1554 *
1555 * This function frees the resources associated with a CANN host buffer, including
1556 * its context.
1557 *
1558 * @param buffer The CANN host buffer to free.
1559 */
1560static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1561 ACL_CHECK(aclrtFreeHost(buffer->context));
1562}
1563
1564/**
1565 * @brief Allocates a new CANN host buffer of the specified size.
1566 *
1567 * This function allocates a new CANN host buffer with the given size.
1568 * @param size Size in bytes of the host buffer to allocate.
1569 * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1570 */
1571static void * ggml_cann_host_malloc(size_t size) {
1572 if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1573 return nullptr;
1574 }
1575
1576 const size_t alignment = 128;
1577 size = GGML_PAD(size, alignment);
1578 if (size == 0) {
1579 size = alignment;
1580 }
1581
1582 void * hostPtr = nullptr;
1583 aclError err = aclrtMallocHost((void **) &hostPtr, size);
1584 if (err != ACL_SUCCESS) {
1585 GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__, size / 1024.0 / 1024.0,
1586 aclGetRecentErrMsg());
1587 return nullptr;
1588 }
1589 return hostPtr;
1590}
1591
1592/**
1593 * @brief Allocates a new CANN host buffer of the specified type and size.
1594 *
1595 * @param buft Pointer to the host buffer type context.
1596 * @param size Size in bytes of the host buffer to allocate.
1597 * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1598 */
1599static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1600 size_t size) {
1601 void * hostPtr = ggml_cann_host_malloc(size);
1602
1603 if (hostPtr == nullptr) {
1604 // fallback to cpu buffer
1605 return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1606 }
1607
1608 ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1609 buffer->buft = buft;
1610 buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1611
1612 return buffer;
1613}
1614
1615/**
1616 * @brief Interface for managing CANN host buffer types in the GGML backend.
1617 *
1618 * Provides function pointers for allocating, querying properties, and managing
1619 * memory for CANN buffer types in the GGML backend.
1620 */
1621ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1622 static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1623 /* .iface = */ {
1624 /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1625 /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1626 /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1627 /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1628 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1629 /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1630 },
1631 /* .device = */
1632 ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1633 /* .context = */ nullptr,
1634 };
1635
1636 return &ggml_backend_cann_buffer_type_host;
1637}
1638
1639/**
1640 * @brief Computes the forward operation for a given tensor using CANN
1641 * operations.
1642 *
1643 * This function selects the appropriate CANN operation based on the type of
1644 * operation specified in the tensor and performs the computation.
1645 *
1646 * @param ctx The CANN context containing necessary resources and
1647 * configurations.
1648 * @param dst The destination tensor where the result of the computation will be
1649 * stored.
1650 * @return true if the computation was successful; false otherwise.
1651 */
1652static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct ggml_tensor * dst) {
1653 switch (dst->op) {
1654 case GGML_OP_REPEAT:
1655 ggml_cann_repeat(ctx, dst);
1656 break;
1657 case GGML_OP_GET_ROWS:
1658 ggml_cann_get_rows(ctx, dst);
1659 break;
1660 case GGML_OP_SET_ROWS:
1661 ggml_cann_set_rows(ctx, dst);
1662 break;
1663 case GGML_OP_DUP:
1664 ggml_cann_dup(ctx, dst);
1665 break;
1666 case GGML_OP_ADD:
1667 case GGML_OP_ADD1:
1668 ggml_cann_binary_op<aclnn_add>(ctx, dst);
1669 break;
1670 case GGML_OP_SUB:
1671 ggml_cann_binary_op<aclnn_sub>(ctx, dst);
1672 break;
1673 case GGML_OP_ACC:
1674 ggml_cann_acc(ctx, dst);
1675 break;
1676 case GGML_OP_MUL:
1677 ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1678 break;
1679 case GGML_OP_DIV:
1680 ggml_cann_binary_op<aclnn_div>(ctx, dst);
1681 break;
1682 case GGML_OP_UNARY:
1683 switch (ggml_get_unary_op(dst)) {
1684 case GGML_UNARY_OP_ABS:
1685 GGML_CANN_CALL_OP_UNARY(Abs);
1686 break;
1687 case GGML_UNARY_OP_NEG:
1688 GGML_CANN_CALL_OP_UNARY(Neg);
1689 break;
1690 case GGML_UNARY_OP_GELU:
1691 case GGML_UNARY_OP_GELU_ERF:
1692 // aclnnGelu internally uses the erf-based approximation.
1693 GGML_CANN_CALL_OP_UNARY(Gelu);
1694 break;
1695 case GGML_UNARY_OP_SILU:
1696 GGML_CANN_CALL_OP_UNARY(Silu);
1697 break;
1698 case GGML_UNARY_OP_GELU_QUICK:
1699 {
1700 auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1701 GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1702 };
1703 ggml_cann_op_unary(lambda, ctx, dst);
1704 }
1705 break;
1706 case GGML_UNARY_OP_TANH:
1707 GGML_CANN_CALL_OP_UNARY(Tanh);
1708 break;
1709 case GGML_UNARY_OP_RELU:
1710 GGML_CANN_CALL_OP_UNARY(Relu);
1711 break;
1712 case GGML_UNARY_OP_SIGMOID:
1713 GGML_CANN_CALL_OP_UNARY(Sigmoid);
1714 break;
1715 case GGML_UNARY_OP_HARDSIGMOID:
1716 GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
1717 break;
1718 case GGML_UNARY_OP_HARDSWISH:
1719 GGML_CANN_CALL_OP_UNARY(Hardswish);
1720 break;
1721 case GGML_UNARY_OP_EXP:
1722 GGML_CANN_CALL_OP_UNARY(Exp);
1723 break;
1724 case GGML_UNARY_OP_ELU:
1725 ggml_cann_elu(ctx, dst);
1726 break;
1727 case GGML_UNARY_OP_SGN:
1728 GGML_CANN_CALL_OP_UNARY(Sign);
1729 break;
1730 case GGML_UNARY_OP_STEP:
1731 ggml_cann_step(ctx, dst);
1732 break;
1733 default:
1734 return false;
1735 }
1736 break;
1737 case GGML_OP_GLU:
1738 switch (ggml_get_glu_op(dst)) {
1739 case GGML_GLU_OP_REGLU:
1740 GGML_CANN_CALL_OP_UNARY_GATED(Relu);
1741 break;
1742 case GGML_GLU_OP_GEGLU:
1743 case GGML_GLU_OP_GEGLU_ERF:
1744 // aclnnGelu internally uses the erf-based approximation.
1745 GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
1746 break;
1747 case GGML_GLU_OP_SWIGLU:
1748 GGML_CANN_CALL_OP_UNARY_GATED(Silu);
1749 break;
1750 case GGML_GLU_OP_GEGLU_QUICK:
1751 {
1752 auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) {
1753 GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
1754 };
1755 ggml_cann_op_unary_gated(lambda, ctx, dst);
1756 }
1757 break;
1758 default:
1759 return false;
1760 }
1761 break;
1762 case GGML_OP_NORM:
1763 ggml_cann_norm(ctx, dst);
1764 break;
1765 case GGML_OP_GROUP_NORM:
1766 ggml_cann_group_norm(ctx, dst);
1767 break;
1768 case GGML_OP_L2_NORM:
1769 ggml_cann_l2_norm(ctx, dst);
1770 break;
1771 case GGML_OP_CROSS_ENTROPY_LOSS:
1772 ggml_cann_cross_entropy_loss(ctx, dst);
1773 break;
1774 case GGML_OP_CONCAT:
1775 ggml_cann_concat(ctx, dst);
1776 break;
1777 case GGML_OP_UPSCALE:
1778 ggml_cann_upsample_nearest2d(ctx, dst);
1779 break;
1780 case GGML_OP_PAD:
1781 ggml_cann_pad(ctx, dst);
1782 break;
1783 case GGML_OP_ARANGE:
1784 ggml_cann_arange(ctx, dst);
1785 break;
1786 case GGML_OP_TIMESTEP_EMBEDDING:
1787 ggml_cann_timestep_embedding(ctx, dst);
1788 break;
1789 case GGML_OP_LEAKY_RELU:
1790 ggml_cann_leaky_relu(ctx, dst);
1791 break;
1792 case GGML_OP_RMS_NORM:
1793 ggml_cann_rms_norm(ctx, dst);
1794 break;
1795 case GGML_OP_MUL_MAT:
1796 ggml_cann_mul_mat(ctx, dst);
1797 break;
1798 case GGML_OP_MUL_MAT_ID:
1799 ggml_cann_mul_mat_id(ctx, dst);
1800 break;
1801 case GGML_OP_SCALE:
1802 ggml_cann_scale(ctx, dst);
1803 break;
1804 case GGML_OP_SQR:
1805 GGML_ASSERT(dst->src[1] == nullptr);
1806 dst->src[1] = dst->src[0];
1807 ggml_cann_binary_op<aclnn_mul>(ctx, dst);
1808 break;
1809 case GGML_OP_SQRT:
1810 GGML_CANN_CALL_OP_UNARY(Sqrt);
1811 break;
1812 case GGML_OP_CLAMP:
1813 ggml_cann_clamp(ctx, dst);
1814 break;
1815 case GGML_OP_CPY:
1816 ggml_cann_cpy(ctx, dst);
1817 break;
1818 case GGML_OP_CONT:
1819 ggml_cann_dup(ctx, dst);
1820 break;
1821 case GGML_OP_NONE:
1822 case GGML_OP_RESHAPE:
1823 case GGML_OP_VIEW:
1824 case GGML_OP_PERMUTE:
1825 case GGML_OP_TRANSPOSE:
1826 break;
1827 case GGML_OP_DIAG_MASK_INF:
1828 ggml_cann_diag_mask(ctx, dst, -INFINITY);
1829 break;
1830 case GGML_OP_SOFT_MAX:
1831 ggml_cann_softmax(ctx, dst);
1832 break;
1833 case GGML_OP_ROPE:
1834 ggml_cann_rope(ctx, dst);
1835 break;
1836 case GGML_OP_IM2COL:
1837 ggml_cann_im2col(ctx, dst);
1838 break;
1839 case GGML_OP_POOL_2D:
1840 ggml_cann_pool2d(ctx, dst);
1841 break;
1842 case GGML_OP_SUM:
1843 ggml_cann_sum(ctx, dst);
1844 break;
1845 case GGML_OP_SUM_ROWS:
1846 ggml_cann_sum_rows(ctx, dst);
1847 break;
1848 case GGML_OP_ARGSORT:
1849 ggml_cann_argsort(ctx, dst);
1850 break;
1851 case GGML_OP_ARGMAX:
1852 ggml_cann_argmax(ctx, dst);
1853 break;
1854 case GGML_OP_COS:
1855 ggml_cann_op_unary<aclnn_cos>(ctx, dst);
1856 break;
1857 case GGML_OP_SIN:
1858 ggml_cann_op_unary<aclnn_sin>(ctx, dst);
1859 break;
1860 case GGML_OP_CONV_TRANSPOSE_1D:
1861 ggml_cann_conv_transpose_1d(ctx, dst);
1862 break;
1863 case GGML_OP_LOG:
1864 GGML_CANN_CALL_OP_UNARY(Log);
1865 break;
1866 case GGML_OP_MEAN:
1867 ggml_cann_mean(ctx, dst);
1868 break;
1869 case GGML_OP_PAD_REFLECT_1D:
1870 ggml_cann_pad_reflect_1d(ctx, dst);
1871 break;
1872 case GGML_OP_COUNT_EQUAL:
1873 ggml_cann_count_equal(ctx, dst);
1874 break;
1875 case GGML_OP_FLASH_ATTN_EXT:
1876 ggml_cann_flash_attn_ext(ctx, dst);
1877 break;
1878 case GGML_OP_OUT_PROD:
1879 ggml_cann_out_prod(ctx, dst);
1880 break;
1881 case GGML_OP_GATED_LINEAR_ATTN:
1882 ggml_cann_gated_linear_attn(ctx, dst);
1883 break;
1884 case GGML_OP_SSM_CONV:
1885 ggml_cann_ssm_conv(ctx, dst);
1886 break;
1887 default:
1888 return false;
1889 }
1890
1891 return true;
1892}
1893
1894// backend
1895/**
1896 * @brief Retrieves the name associated with the CANN backend.
1897 *
1898 * This function returns the name assigned to the CANN backend, which is stored
1899 * in the context of the provided backend structure.
1900 *
1901 * @param backend Pointer to the CANN backend structure.
1902 * @return A pointer to a constant string representing the backend name.
1903 */
1904static const char * ggml_backend_cann_name(ggml_backend_t backend) {
1905 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1906
1907 return cann_ctx->name.c_str();
1908}
1909
1910/**
1911 * @brief Frees resources associated with the CANN backend.
1912 *
1913 * This function releases resources associated with the CANN backend context
1914 * and resets the device associated with the backend to its initial state.
1915 *
1916 * @param backend Pointer to the CANN backend structure to be freed.
1917 */
1918static void ggml_backend_cann_free(ggml_backend_t backend) {
1919 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1920 ACL_CHECK(aclrtSynchronizeDevice());
1921 ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1922
1923 delete cann_ctx;
1924 delete backend;
1925}
1926
1927/**
1928 * @brief Sets tensor data asynchronously in the CANN backend.
1929 *
1930 * This function asynchronously sets tensor data in the CANN backend.
1931 *
1932 * @param backend Pointer to the CANN backend structure.
1933 * @param tensor Pointer to the tensor structure to set data for.
1934 * @param data Pointer to the host data to copy to the tensor.
1935 * @param offset Offset in bytes within the host data.
1936 * @param size Size of the data to copy in bytes.
1937 */
1938static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1939 ggml_tensor * tensor,
1940 const void * data,
1941 size_t offset,
1942 size_t size) {
1943 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1944 ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1945
1946 GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
1947 GGML_ASSERT(!ggml_is_quantized(tensor->type));
1948
1949 ACL_CHECK(aclrtMemcpyAsync((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE,
1950 cann_ctx->stream()));
1951}
1952
1953/**
1954 * @brief Gets tensor data asynchronously in the CANN backend.
1955 *
1956 * This function asynchronously gets tensor data in the CANN backend.
1957 *
1958 * @param backend Pointer to the CANN backend structure.
1959 * @param tensor Pointer to the tensor structure to get data from.
1960 * @param data Pointer to the host data to copy from the tensor.
1961 * @param offset Offset in bytes within the host data.
1962 * @param size Size of the data to copy in bytes.
1963 */
1964static void ggml_backend_cann_get_tensor_async(ggml_backend_t backend,
1965 const ggml_tensor * tensor,
1966 void * data,
1967 size_t offset,
1968 size_t size) {
1969 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
1970 ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1971
1972 GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) && "unsupported buffer type");
1973 GGML_ASSERT(!ggml_is_quantized(tensor->type));
1974
1975 ACL_CHECK(aclrtMemcpyAsync(data, size, (char *) tensor->data + offset, size, ACL_MEMCPY_DEVICE_TO_HOST,
1976 cann_ctx->stream()));
1977}
1978
/**
 * @brief Asynchronously copies tensor data between CANN backends.
 *
 * This function copies tensor data asynchronously between two CANN backends. It
 * checks if both tensors reside in CANN buffers and whether the devices support
 * peer-to-peer access for direct copying. If not, it returns false.
 *
 * Cross-device copies require peer access to be enabled in BOTH directions and
 * currently synchronize the source stream after enqueueing the copy (see the
 * note below about acl graph mode).
 *
 * @param backend_src Pointer to the source CANN backend structure.
 * @param backend_dst Pointer to the destination CANN backend structure.
 * @param src Pointer to the source tensor to copy data from.
 * @param dst Pointer to the destination tensor to copy data to.
 * @return true if the copy operation succeeds, false otherwise.
 */
static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
                                               ggml_backend_t backend_dst,
                                               const ggml_tensor * src,
                                               ggml_tensor * dst) {
    GGML_ASSERT(ggml_backend_is_cann(backend_src) || ggml_backend_is_cann(backend_dst));

    // Weight tensors may be stored in a special (e.g. NZ) layout; copying them
    // raw would be incorrect, so they must not go through this path.
    GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));

    if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
        return false;
    }

    // For views, the storage is owned by the viewed tensor's buffer.
    ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer;
    ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer;

    ggml_backend_cann_context * cann_ctx_src = (ggml_backend_cann_context *) backend_src->context;
    ggml_backend_cann_context * cann_ctx_dst = (ggml_backend_cann_context *) backend_dst->context;

    size_t copy_size = ggml_nbytes(dst);
    if (copy_size == 0) {
        // Nothing to move; report success so the caller does not fall back.
        return true;
    }
    if (backend_src != backend_dst) {
#ifdef ASCEND_310P
        // TODO: Support 310p P2P copy
        return false;
#endif
        ggml_backend_cann_buffer_context * buf_ctx_src = (ggml_backend_cann_buffer_context *) buf_src->context;
        ggml_backend_cann_buffer_context * buf_ctx_dst = (ggml_backend_cann_buffer_context *) buf_dst->context;

        // Each backend must operate on buffers of its own device.
        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);

        int32_t canAccessPeer = 0;
        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device, cann_ctx_dst->device));
        if (!canAccessPeer) {
            return false;
        }

        // need open both directions for memcpyasync between devices.
        // NOTE: the first enable call runs on the currently-set device, then we
        // switch to the source device before enabling access to the destination.
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
        ggml_cann_set_device(cann_ctx_src->device);
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));

        // wait for task_queue empty to keep task order.
        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_src->stream()));
        // record event on src stream after the copy
        // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
        // if (!cann_ctx_src->copy_event) {
        //     ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
        // }
        // ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));

        // // wait on dst stream for the copy to complete
        // ggml_cann_set_device(cann_ctx_dst->device);
        // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
        ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
    } else {
        // src and dst are on the same backend: a plain on-stream device copy
        // preserves ordering with the rest of the backend's work.
        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_dst->stream()));
    }

    return true;
}
2058
2059/**
2060 * @brief Synchronizes a CANN backend.
2061 *
2062 * This function synchronizes the specified CANN backend by waiting for all
2063 * operations in its associated stream to complete.
2064 *
2065 * @param backend Pointer to the CANN backend structure to synchronize.
2066 */
2067static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
2068 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2069 ggml_cann_set_device(cann_ctx->device);
2070 ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
2071}
2072
2073/**
2074 * @brief Check if CANN backend can fuse the specified operation sequence
2075 *
2076 * This function determines whether an operation sequence starting from the specified node
2077 * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
2078 * memory access overhead and improve computational efficiency.
2079 *
2080 * @param cgraph Pointer to the computation graph
2081 * @param node_idx Index of the starting node in the computation graph
2082 * @param ops Sequence of operation types to check for fusion
2083 * @return true if the operations can be fused
2084 * @return false if the operations cannot be fused
2085 */
2086static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
2087 int node_idx,
2088 std::initializer_list<enum ggml_op> ops) {
2089 if (!ggml_can_fuse(cgraph, node_idx, ops)) {
2090 return false;
2091 }
2092
2093 // CANN backend supports fusing ADD + RMS_NORM operations
2094 if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
2095 ggml_tensor * add_node = cgraph->nodes[node_idx];
2096 // TODO: support broadcast for ADD + RMS_NORM
2097 if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
2098 add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
2099 return false;
2100 }
2101 return true;
2102 }
2103
2104 return false;
2105}
2106
/**
 * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
 *
 * If CANN graph execution is enabled and graph capture is required, this function begins
 * graph capture, runs the graph node-by-node (recording the kernels), ends capture, and
 * stores the captured graph in the most-recently-used cache slot.
 *
 * Otherwise, it falls back to op-by-op execution using the CANN compute kernel dispatcher.
 * When a cached graph already exists and no recapture is needed, only the graph launch runs.
 *
 * @param cann_ctx The CANN backend context.
 * @param cgraph The ggml computation graph.
 * @param use_cann_graph Whether to use CANN graph execution.
 * @param cann_graph_capture_required Whether graph capture is needed due to graph changes.
 */
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx,
                                            ggml_cgraph * cgraph,
                                            bool use_cann_graph,
                                            bool cann_graph_capture_required) {
#ifdef USE_ACL_GRAPH
    if (use_cann_graph && cann_graph_capture_required) { // Begin CANN graph capture
        ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
    }
#endif // USE_ACL_GRAPH
    // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
    // With the use of CANN graphs, the execution will be performed by the graph launch.
    // Opt-in operator fusion, read once per process from the environment.
    static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));

    if (!use_cann_graph || cann_graph_capture_required) {
        for (int i = 0; i < cgraph->n_nodes; i++) {
            ggml_tensor * node = cgraph->nodes[i];
            if (opt_fusion) {
                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
                    // Fused kernel consumes two consecutive nodes; skip the second one.
                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
                    i++;
                    continue;
                }
            }

            // View-like ops have no computation of their own.
            if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
                continue;
            }

            if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
                continue;
            }

            bool ok = ggml_cann_compute_forward(*cann_ctx, node);
            if (!ok) {
                GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
            }
            GGML_ASSERT(ok);
        }
    }

#ifdef USE_ACL_GRAPH
    if (use_cann_graph) {
        // The current graph was moved/pushed to the cache front by the caller.
        GGML_ASSERT(!cann_ctx->graph_lru_cache.cache_list.empty());
        ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();

        if (cann_graph_capture_required) { // End CANN graph capture
            ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
        }

        // Execute CANN graph
        ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
    }
#endif // USE_ACL_GRAPH
}
2175
/**
 * @brief Computes a computational graph using a CANN backend.
 *
 * This function computes the operations defined in the computational graph
 * using the specified CANN backend. When built with USE_ACL_GRAPH it decides
 * whether to run through the ACL graph capture/replay path (decode-only by
 * default) and manages the per-context LRU cache of captured graphs.
 *
 * @param backend Pointer to the CANN backend structure to use for computation.
 * @param cgraph Pointer to the computational graph structure containing nodes
 *               representing operations to be computed.
 * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
 * completes successfully, otherwise an appropriate error status.
 */
static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
    ggml_cann_set_device(cann_ctx->device);
    // Reset the per-device NZ weight workspace before this graph runs.
    g_nz_workspaces[cann_ctx->device].clear();

    // calculate rope cache for first layer in current device.
    cann_ctx->rope_cache.cached = false;

    bool graph_capture_required = false;
#ifdef USE_ACL_GRAPH
    bool use_cann_graph = true;

    // Opt-in: allow ACL graph during prefill (read once per process).
    static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
    if (!prefill_use_graph) {
        // Do not use acl_graph for prefill.
        for (int i = 0; i < cgraph->n_nodes; i++) {
            ggml_tensor * node = cgraph->nodes[i];
            // TODO: Optimize here. Currently, we can only
            // get seq_len by FA's input.
            if (node->op == GGML_OP_FLASH_ATTN_EXT) {
                // Q -> src[0], shape: [B, S, N, D]
                // seq_len == 1 means decode; only then use the captured graph.
                use_cann_graph = (node->src[0]->ne[1] == 1);
                break;
            }
        }
    }

    if (!cann_ctx->acl_graph_mode) {
        use_cann_graph = false;
    }

    if (use_cann_graph) {
        // If no matching graph is found, the graph needs to be recaptured.
        graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph);
        if (graph_capture_required) {
            // If no matching graph is found, add a new ACL graph.
            ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph);
            cann_ctx->graph_lru_cache.push(new_graph);
        }
    }
#else
    bool use_cann_graph = false;
#endif // USE_ACL_GRAPH
    evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, graph_capture_required);

    return GGML_STATUS_SUCCESS;
}
2235
/**
 * @brief Checks if the CANN backend supports a specific operation.
 *
 * This function checks whether the specified operation is supported by the
 * CANN backend, taking into account per-op type restrictions and
 * device-generation limits (ASCEND_310P builds disable several paths).
 *
 * @param dev Device the support query is made for (unused; support is
 *            decided per build/op, not per device instance).
 * @param op Pointer to the tensor representing the operation to check.
 * @return bool Returns true if the operation is supported by the backend,
 * otherwise false.
 */
static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_ABS:
                case GGML_UNARY_OP_NEG:
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_SIGMOID:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_TANH:
                case GGML_UNARY_OP_EXP:
                case GGML_UNARY_OP_ELU:
                case GGML_UNARY_OP_SGN:
                case GGML_UNARY_OP_STEP:
                case GGML_UNARY_OP_GELU_ERF:
                    return true;
                default:
                    return false;
            }
        case GGML_OP_GLU:
            switch (ggml_get_glu_op(op)) {
                case GGML_GLU_OP_REGLU:
                case GGML_GLU_OP_GEGLU:
                case GGML_GLU_OP_SWIGLU:
                case GGML_GLU_OP_GEGLU_ERF:
                case GGML_GLU_OP_GEGLU_QUICK:
                    return true;
                default:
                    return false;
            }
            break;
        case GGML_OP_MUL_MAT:
            {
                switch (op->src[0]->type) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                        return true;
                    case GGML_TYPE_Q8_0:
                    case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
                        // Q4 && Q8 per group is not support on 310p device
                        return false;
#endif
                        // only support contiguous for quantized types.
                        return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
                    default:
                        return false;
                }
            }
        case GGML_OP_MUL_MAT_ID:
            switch (op->src[0]->type) {
                case GGML_TYPE_F16:
                case GGML_TYPE_F32:
                    return true;
                case GGML_TYPE_Q8_0:
                case GGML_TYPE_Q4_0:
#ifdef ASCEND_310P
                    // Q4 && Q8 per group is not support on 310p device
                    return false;
#endif
                    // only support contiguous for quantized types.
                    return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
                default:
                    return false;
            }
        // embedding
        case GGML_OP_GET_ROWS:
            {
                switch (op->src[0]->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_F16:
                    case GGML_TYPE_Q8_0:
                        return true;
                    default:
                        return false;
                }
            }
            break;
        case GGML_OP_SET_ROWS:
            {
                switch (op->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_F16:
                        return true;
                    default:
                        return false;
                }
            }
            break;
        case GGML_OP_CPY:
            {
                ggml_tensor * src = op->src[0];
                if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) ||
                    (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) {
                    // only support F32 and F16.
                    return false;
                }
                return true;
            }
            break;
        case GGML_OP_CONT:
            {
                // TODO: support GGML_TYPE_BF16
                switch (op->src[0]->type) {
                    case GGML_TYPE_F32:
                    case GGML_TYPE_F16:
                        return true;
                    default:
                        return false;
                }
            }
        case GGML_OP_ROPE:
            {
                // head dim > 896 is not supported by the CANN rope kernel.
                if (op->src[0]->ne[0] > 896) {
                    return false;
                }
#ifdef ASCEND_310P
                // TODO: Support rope_dim < ne00(dim)
                if (op->src[0]->ne[0] != op->op_params[1]) {
                    return false;
                }
                if (!ggml_is_contiguous(op->src[0])) {
                    return false;
                }
#endif
                return true;
            }
        case GGML_OP_UPSCALE:
            {
                // aclnnUpsampleNearest2dGetWorkspaceSize not support
                // selfDimN[2]/outDimN[2] or selfDimC[3]/outDimC[3] not equal
                if (op->src[0]->ne[2] * op->ne[3] != op->src[0]->ne[3] * op->ne[2]) {
                    return false;
                }
                if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) {
                    return false;
                }
                // NOTE(review): if scale flags share op_params[0] with the mode,
                // any flagged value already fails the strict NEAREST equality
                // above, making this check unreachable — confirm flag encoding.
                if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
                    return false;
                }
                return true;
            }
        case GGML_OP_POOL_2D:
            {
                const int32_t * opts = (const int32_t *) op->op_params;
#ifdef ASCEND_310P
                enum ggml_op_pool opt = static_cast<ggml_op_pool>(opts[0]);
                if (opt == GGML_OP_POOL_MAX) {
                    return false;
                }
#endif
                const int k0 = opts[1];
                const int k1 = opts[2];
                const int p0 = opts[5];
                const int p1 = opts[6];
                // value of paddingH should be at most half of kernelH
                // value of paddingW should be at most half of kernelW
                return (p0 <= (k0 / 2)) && (p1 <= (k1 / 2));
            }
        case GGML_OP_SUM:
            return ggml_is_contiguous_rows(op->src[0]);
        case GGML_OP_L2_NORM:
        case GGML_OP_CROSS_ENTROPY_LOSS:
        case GGML_OP_DUP:
        case GGML_OP_IM2COL:
        case GGML_OP_CONCAT:
        case GGML_OP_REPEAT:
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_ADD:
        case GGML_OP_ADD1:
        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SQR:
        case GGML_OP_SQRT:
        case GGML_OP_CLAMP:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
        case GGML_OP_GROUP_NORM:
            return true;
        case GGML_OP_PAD:
            // TODO: add circular padding support for cann, see https://github.com/ggml-org/llama.cpp/pull/16985
            return ggml_get_op_params_i32(op, 8) == 0;
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
        case GGML_OP_ARGMAX:
        case GGML_OP_COS:
        case GGML_OP_SIN:
        case GGML_OP_LOG:
        case GGML_OP_MEAN:
        case GGML_OP_PAD_REFLECT_1D:
        case GGML_OP_COUNT_EQUAL:
        case GGML_OP_GATED_LINEAR_ATTN:
            return true;
        case GGML_OP_OUT_PROD:
            {
#ifdef ASCEND_310P
                // Ger is not supported on 310p device
                return false;
#endif
                switch (op->src[0]->type) {
                    case GGML_TYPE_F16:
                    case GGML_TYPE_F32:
                        return true;
                    default:
                        return false;
                }
            }
        case GGML_OP_CONV_TRANSPOSE_1D:
            return true;
        case GGML_OP_SCALE:
            float bias;
            memcpy(&bias, (const float *) (op->op_params) + 1, sizeof(float));
            return bias == 0.0f;  // TODO: support bias != 0.0f
        case GGML_OP_SOFT_MAX:
            // TODO: support attention sinks [TAG_ATTN_SINKS]
            if (op->src[2]) {
                return false;
            }
            return true;
        case GGML_OP_FLASH_ATTN_EXT:
            {
#ifdef ASCEND_310P
                // FA not support on 310p device
                return false;
#endif
                // derived from [ggml-cuda.cu]
                // K and V must both be F16.
                if (op->src[1]->type != GGML_TYPE_F16 || op->src[2]->type != GGML_TYPE_F16) {
                    return false;
                }
                // NOTE(review): unreachable — src[1] is already known to be F16
                // from the check above, so this condition can never be true.
                if (op->src[1]->type != GGML_TYPE_F16 && op->src[1]->type != GGML_TYPE_F32 &&
                    op->src[1]->type != GGML_TYPE_BF16) {
                    return false;
                }
                if (op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16) {
                    return false;
                }
                // TODO: support attention sinks [TAG_ATTN_SINKS]
                if (op->src[4]) {
                    return false;
                }
                if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                    // different head sizes of K and V are not supported yet
                    return false;
                }
                if (op->src[0]->ne[0] % 16 != 0) {
                    // TODO: padding to support
                    return false;
                }
                float logitSoftcap = 0.0f;
                memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float));
                if (logitSoftcap != 0.0f) {
                    return false;
                }
                return true;
            }
        case GGML_OP_SSM_CONV:
            return true;
        default:
            return false;
    }

    GGML_UNUSED(dev);
}
2525
2526/**
2527 * @brief Records an event on the CANN backend stream.
2528 *
2529 * This function records the given event on the ACL runtime stream associated
2530 * with the backend context.
2531 *
2532 * @param event Pointer to the event structure to be recorded.
2533 */
2534static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
2535 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2536 ACL_CHECK(aclrtRecordEvent((aclrtEvent) event->context, cann_ctx->stream()));
2537}
2538
2539/**
2540 * @brief Waits for a recorded event to complete on the CANN backend stream.
2541 *
2542 * This function makes the given backend wait for the event to complete on its
2543 * ACL runtime stream.
2544 *
2545 * @param backend Pointer to the backend structure.
2546 * @param event Pointer to the event structure that the backend needs to wait
2547 * for.
2548 */
2549static void ggml_backend_cann_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
2550 ggml_backend_cann_context * cann_ctx = (ggml_backend_cann_context *) backend->context;
2551 if (ggml_backend_is_cann(backend)) {
2552 ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(), (aclrtEvent) event->context));
2553 } else {
2554 GGML_ABORT("fatal error");
2555 }
2556}
2557
/**
 * @brief Structure defining the interface for the CANN backend.
 *
 * This structure contains function pointers for various operations
 * supported by the CANN backend, including name retrieval, memory
 * management, tensor operations, synchronization, and event handling.
 * Entries set to NULL (graph plans, graph optimization) are optional
 * capabilities this backend does not implement.
 */
static const ggml_backend_i ggml_backend_cann_interface = {
    /* .get_name                = */ ggml_backend_cann_name,
    /* .free                    = */ ggml_backend_cann_free,
    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,
    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,
    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,
    /* .synchronize             = */ ggml_backend_cann_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_cann_graph_compute,
    /* .event_record            = */ ggml_backend_cann_event_record,
    /* .event_wait              = */ ggml_backend_cann_event_wait,
    /* .graph_optimize          = */ NULL,
};
2581
2582/**
2583 * @brief Return the hardcoded GUID for the CANN backend.
2584 *
2585 * This function returns a static GUID which uniquely identifies the CANN
2586 * backend.
2587 *
2588 * @return A pointer to the static GUID.
2589 */
2590static ggml_guid_t ggml_backend_cann_guid() {
2591 static ggml_guid guid = { 0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
2592 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64 };
2593 return &guid;
2594}
2595
// backend device
/** @brief Per-device context backing a ggml_backend_dev_t for CANN. */
struct ggml_backend_cann_device_context {
    int device;                     // CANN device ordinal
    std::string name;               // device name exposed through the backend API
    std::string description;        // human-readable device description
    int op_offload_min_batch_size;  // minimum ne[1] for an op to be offloaded to this device
};
2603
2604static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
2605 ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2606 return ctx->name.c_str();
2607}
2608
2609static const char * ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
2610 ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2611 return ctx->description.c_str();
2612}
2613
2614static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
2615 ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2616 ggml_backend_cann_get_device_memory(ctx->device, free, total);
2617}
2618
2619static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
2620 GGML_UNUSED(dev);
2621 return GGML_BACKEND_DEVICE_TYPE_GPU;
2622}
2623
2624static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
2625 props->name = ggml_backend_cann_device_get_name(dev);
2626 props->description = ggml_backend_cann_device_get_description(dev);
2627 props->type = ggml_backend_cann_device_get_type(dev);
2628 ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
2629
2630 bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
2631
2632 props->caps = {
2633 /* .async = */ false,
2634 /* .host_buffer = */ host_buffer,
2635 /* .buffer_from_host_ptr = */ false,
2636 /* .events = */ true,
2637 };
2638}
2639
2640static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
2641 GGML_UNUSED(params);
2642 ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2643 return ggml_backend_cann_init(ctx->device);
2644}
2645
2646/**
2647 * @brief Checks if the CANN backend supports a specific backend buffer type.
2648 *
2649 * This function determines whether the CANN backend supports the given backend
2650 * buffer type by comparing the device context of the backend and buffer type.
2651 * It returns true if the devices are same between the backend context and
2652 * buffer type context.
2653 *
2654 * @param backend Pointer to the CANN backend.
2655 * @param buft Pointer to the backend buffer type to check.
2656 * @return bool Returns true if the CANN backend supports the buffer type,
2657 * otherwise false.
2658 */
2659static bool ggml_backend_cann_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
2660 if (ggml_backend_buft_is_cann(buft)) {
2661 ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2662 ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
2663 return buft_ctx->device == dev_ctx->device;
2664 }
2665 return false;
2666}
2667
2668static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
2669 ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *) dev->context;
2670 return ggml_backend_cann_buffer_type(ctx->device);
2671}
2672
2673static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
2674 GGML_UNUSED(dev);
2675 return ggml_backend_cann_host_buffer_type();
2676}
2677
2678/**
2679 * @brief Determines if a tensor operation should be offloaded to the CANN
2680 * backend.
2681 *
2682 * This function checks if a given tensor operation should be offloaded to the
2683 * CANN backend based on the operation type and the size of the tensor. It
2684 * returns true if the second dimension (ne[1]) of the tensor is greater than or
2685 * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
2686 *
2687 * @param backend Pointer to the CANN backend.
2688 * @param op Pointer to the tensor operation to check.
2689 * @return bool Returns true if the operation should be offloaded, otherwise
2690 * false.
2691 */
2692static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
2693 ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
2694
2695 return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS;
2696}
2697
2698/**
2699 * @brief Creates a new event for the CANN backend device.
2700 *
2701 * This function initializes a new event for the CANN backend by setting the
2702 * device and creating an ACL runtime event. The created event is then wrapped
2703 * in a ggml_backend_event structure and returned.
2704 *
2705 * @param backend Pointer to the CANN backend.
2706 * @return ggml_backend_event_t Returns a pointer to the new event structure.
2707 */
2708static ggml_backend_event_t ggml_backend_cann_device_event_new(ggml_backend_dev_t dev) {
2709 ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *) dev->context;
2710
2711 ggml_cann_set_device(dev_ctx->device);
2712
2713 aclrtEvent event;
2714 ACL_CHECK(aclrtCreateEvent(&event));
2715
2716 return new ggml_backend_event{
2717 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
2718 /* .context = */ event,
2719 };
2720}
2721
/**
 * @brief Frees a CANN backend event.
 *
 * This function destroys the ACL runtime event associated with the given CANN
 * backend event and then deletes the event structure itself.
 *
 * @param dev Pointer to the CANN backend device (unused).
 * @param event Pointer to the event structure to be freed.
 */
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ACL_CHECK(aclrtDestroyEvent((aclrtEvent) event->context));

    delete event;
    GGML_UNUSED(dev);
}
2736
/**
 * @brief Synchronizes the given event on the CANN backend.
 *
 * This function waits for the specified event to complete on the ACL runtime.
 *
 * @param dev Pointer to the CANN backend device (unused).
 * @param event Pointer to the event structure to be synchronized.
 */
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent) event->context));

    GGML_UNUSED(dev);
}
2749
/**
 * @brief Function table implementing the ggml device interface for CANN.
 *
 * Entries set to NULL are features this backend does not implement.
 */
static const ggml_backend_device_i ggml_backend_cann_device_interface = {
    /* .get_name                = */ ggml_backend_cann_device_get_name,
    /* .get_description         = */ ggml_backend_cann_device_get_description,
    /* .get_memory              = */ ggml_backend_cann_device_get_memory,
    /* .get_type                = */ ggml_backend_cann_device_get_type,
    /* .get_props               = */ ggml_backend_cann_device_get_props,
    /* .init_backend            = */ ggml_backend_cann_device_init,  // called for every card
    /* .get_buffer_type         = */ ggml_backend_cann_device_get_buffer_type,
    /* .get_host_buffer_type    = */ ggml_backend_cann_device_get_host_buffer_type,
    /* .buffer_from_host_ptr    = */ NULL,  // not supported for CANN
    /* .supports_op             = */ ggml_backend_cann_supports_op,
    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
    /* .offload_op              = */ ggml_backend_cann_offload_op,
    /* .event_new               = */ ggml_backend_cann_device_event_new,
    /* .event_free              = */ ggml_backend_cann_device_event_free,
    /* .event_synchronize       = */ ggml_backend_cann_device_event_synchronize,
};
2767
2768// backend reg
/**
 * @brief Context for the CANN backend registry.
 */
struct ggml_backend_cann_reg_context {
    std::vector<ggml_backend_dev_t> devices;  // one entry per CANN device, populated at registry init
};
2772
/**
 * @brief Returns the human-readable name of the CANN backend registry.
 */
static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
    GGML_UNUSED(reg);
    return GGML_CANN_NAME;
}
2777
2778static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2779 ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2780 return ctx->devices.size();
2781}
2782
2783static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2784 ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *) reg->context;
2785 GGML_ASSERT(index < ctx->devices.size());
2786 return ctx->devices[index];
2787}
2788
/**
 * @brief Looks up a backend-specific entry point by name.
 *
 * The CANN backend currently exposes no extra entry points, so this always
 * returns nullptr.
 */
static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    GGML_UNUSED(reg);
    GGML_UNUSED(name);
    // reserved for future use
    return nullptr;
}
2795
/**
 * @brief Function table implementing the ggml registry interface for CANN.
 */
static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
    /* .get_name          = */ ggml_backend_cann_reg_get_name,
    /* .get_device_count  = */ ggml_backend_cann_reg_get_device_count,
    /* .get_device        = */ ggml_backend_cann_reg_get_device,
    /* .get_proc_address  = */ ggml_backend_cann_reg_get_proc_address,
};
2802
2803// backend registry, called only once for cann backend
2804ggml_backend_reg_t ggml_backend_cann_reg() {
2805 static ggml_backend_reg reg;
2806 static bool initialized = false;
2807
2808 {
2809 static std::mutex mutex;
2810 std::lock_guard<std::mutex> lock(mutex);
2811 if (!initialized) {
2812 aclInit(nullptr);
2813 ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2814 const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
2815
2816 for (int i = 0; i < ggml_cann_info().device_count; i++) {
2817 ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context();
2818 dev_ctx->description = aclrtGetSocName();
2819 dev_ctx->device = i;
2820 dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2821 dev_ctx->op_offload_min_batch_size = min_batch_size;
2822 ggml_cann_set_device(i);
2823 ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface,
2824 /* .reg = */ &reg,
2825 /* .context = */ dev_ctx };
2826 ctx->devices.push_back(dev);
2827 }
2828
2829 reg = ggml_backend_reg{ /* .api_version = */ GGML_BACKEND_API_VERSION,
2830 /* .iface = */ ggml_backend_cann_reg_interface,
2831 /* .context = */ ctx };
2832 }
2833
2834 initialized = true;
2835 }
2836
2837 return &reg;
2838}
2839
2840ggml_backend_t ggml_backend_cann_init(int32_t device) {
2841 aclInit(nullptr);
2842 if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2843 GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2844 return nullptr;
2845 }
2846
2847 ggml_backend_cann_context * ctx = new ggml_backend_cann_context(device);
2848 if (ctx == nullptr) {
2849 GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2850 return nullptr;
2851 }
2852 ggml_cann_set_device(ctx->device);
2853 ggml_backend_t cann_backend =
2854 new ggml_backend{ /* .guid = */ ggml_backend_cann_guid(),
2855 /* .interface = */ ggml_backend_cann_interface,
2856 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2857 /* .context = */ ctx };
2858
2859 return cann_backend;
2860}
2861
2862bool ggml_backend_is_cann(ggml_backend_t backend) {
2863 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2864}
2865
2866int32_t ggml_backend_cann_get_device_count() {
2867 return ggml_cann_info().device_count;
2868}
2869
2870void ggml_backend_cann_get_device_description(int32_t device, char * description, size_t description_size) {
2871 ggml_cann_set_device(device);
2872 const char * soc_name = aclrtGetSocName();
2873 snprintf(description, description_size, "%s", soc_name);
2874}
2875
/**
 * @brief Queries free and total memory of a CANN device.
 *
 * @param device Zero-based CANN device index.
 * @param free Output: free device memory in bytes.
 * @param total Output: total device memory in bytes.
 */
void ggml_backend_cann_get_device_memory(int32_t device, size_t * free, size_t * total) {
    ggml_cann_set_device(device);
    // ACL_HBM_MEM selects the NPU's high-bandwidth memory attribute.
    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
}
2880
2881GGML_BACKEND_DL_IMPL(ggml_backend_cann_reg)