#include "ggml-zendnn.h"

#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "zendnnl.hpp"

#include <cassert>
#include <cstring>
#include <memory>
#include <type_traits>
8
9
// Per-backend state: thread count handed to the ZenDNN primitives plus a
// lazily grown scratch buffer used to convert inputs to the weight dtype.
struct ggml_backend_zendnn_context {
    int n_threads = GGML_DEFAULT_N_THREADS; // threads passed to ZenDNN matmul
    std::unique_ptr<char[]> work_data;      // scratch for src1 dtype conversion
    size_t work_size = 0;                   // current capacity of work_data, in bytes
};
15
16template<typename T>
17zendnnl::common::data_type_t ggml_to_zendnn_type() {
18 if constexpr (std::is_same_v<T, float>) {
19 return zendnnl::common::data_type_t::f32;
20 } else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
21 return zendnnl::common::data_type_t::bf16;
22 } else {
23 return zendnnl::common::data_type_t::none;
24 }
25}
26
27/**
28 * ZenDNN matmul: computes C = B * A.
29 *
30 * - A: weights, shape (k, m), column-major (each column is a weight vector for one output).
31 * - B: input, shape (n, k), row-major (each row is an input sample).
32 * - C: output, shape (n, m), row-major.
33 *
34 * Dimensions:
35 * m = output features (columns of C, columns of A)
36 * n = batch size (rows of C, rows of B)
37 * k = inner dimension (columns of B, rows of A)
38 */
39template <typename TA, typename TB, typename TC>
40static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
41 const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
42 int64_t ldc) {
43
44 zendnnl::lowoha::lowoha_params params;
45 params.dtypes.src = ggml_to_zendnn_type<TB>();
46 params.dtypes.wei = ggml_to_zendnn_type<TA>();
47 params.dtypes.dst = ggml_to_zendnn_type<TC>();
48 params.num_threads = ctx->n_threads;
49
50 zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
51 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major)
52 n, // M: rows of B and C
53 m, // N: cols of A^T and C
54 k, // K: cols of B, rows of A
55 1.0f, // alpha
56 B, ldb, // src: B[n,k]
57 A, lda, // weight: A[k,m] column-major (transposed)
58 nullptr, // bias
59 0.0f, // beta
60 C, ldc, // output C[n,m]
61 true, // is_weights_const
62 {}, // batch_params
63 params // params
64 );
65
66 if (status != zendnnl::lowoha::status_t::success) {
67 GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
68 return false;
69 }
70 return true;
71}
72
73static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
74 const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
75 int64_t ldc, int Atype, int Btype, int Ctype) {
76
77 assert(m >= 0);
78 assert(n >= 0);
79 assert(k >= 0);
80 assert(lda >= k);
81 assert(ldb >= k);
82 assert(ldc >= m);
83
84 // categorize types
85 switch (Atype) {
86 case GGML_TYPE_F32:
87 if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
88 return false;
89 return ggml_zendnn_matmul<float, float, float>(
90 ctx, m, n, k,
91 (const float *)A, lda,
92 (const float *)B, ldb,
93 (float *)C, ldc);
94 case GGML_TYPE_BF16:
95 if (Btype != GGML_TYPE_BF16)
96 return false;
97 if (Ctype == GGML_TYPE_BF16)
98 return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(
99 ctx, m, n, k,
100 (const ggml_bf16_t *)A, lda,
101 (const ggml_bf16_t *)B, ldb,
102 (ggml_bf16_t *)C, ldc);
103 if (Ctype == GGML_TYPE_F32)
104 return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, float>(
105 ctx, m, n, k,
106 (const ggml_bf16_t *)A, lda,
107 (const ggml_bf16_t *)B, ldb,
108 (float *)C, ldc);
109 return false;
110 default:
111 return false; // unsupported type
112 }
113}
114
/**
 * Forward MUL_MAT: computes dst = src0 (weights) x src1 (inputs), batched
 * over dims 2/3, with src0 broadcast across batches (factors r2/r3).
 *
 * When src1's type differs from src0's (e.g. F32 inputs with BF16 weights),
 * src1 rows are first converted into ctx->work_data, which is grown on
 * demand and reused across calls.
 */
static void ggml_zendnn_compute_forward_mul_mat(
        ggml_backend_zendnn_context * ctx,
        ggml_tensor * dst) {

    const ggml_tensor * src0 = dst->src[0]; // weights
    const ggml_tensor * src1 = dst->src[1]; // inputs

    // declares local ne0x/nb0x (src0), ne1x/nb1x (src1), nex/nbx (dst)
    GGML_TENSOR_BINARY_OP_LOCALS

    // inputs are converted to the weight dtype before the GEMM
    ggml_type const vec_dot_type = src0->type;
    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;

    GGML_ASSERT(ne0 == ne01);
    GGML_ASSERT(ne1 == ne11);
    GGML_ASSERT(ne2 == ne12);
    GGML_ASSERT(ne3 == ne13);

    // we don't support permuted src0 or src1
    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
    GGML_ASSERT(nb10 == ggml_type_size(src1->type));

    // dst cannot be transposed or permuted
    GGML_ASSERT(nb0 == sizeof(float));
    GGML_ASSERT(nb0 <= nb1);
    GGML_ASSERT(nb1 <= nb2);
    GGML_ASSERT(nb2 <= nb3);

    // broadcast factors: src0's batch dims may be smaller than src1's
    const int64_t r2 = ne12/ne02;
    const int64_t r3 = ne13/ne03;

    void * work_data = ctx->work_data.get();
    if (src1->type != vec_dot_type) {
        // convert src1 (F32) row by row into a tightly packed scratch buffer
        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); // bytes per converted row
        const size_t nbw2 = nbw1 * ne11;                       // bytes per dim-2 slice
        const size_t nbw3 = nbw2 * ne12;                       // bytes per dim-3 slice
        const size_t desired_wsize = ne13 * nbw3;
        if (ctx->work_size < desired_wsize) {
            // grow-only allocation; capacity is kept for later calls
            ctx->work_data.reset(new char[desired_wsize]);
            ctx->work_size = desired_wsize;
        }
        work_data = ctx->work_data.get();

        // #pragma omp parallel for num_threads(ctx->n_threads)
        #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
        for (int64_t i13 = 0; i13 < ne13; ++i13) {
            for (int64_t i12 = 0; i12 < ne12; ++i12) {
                for (int64_t i11 = 0; i11 < ne11; ++i11) {
                    const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
                    void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
                    from_float(src1_f32, src1_conv, ne10);
                }
            }
        }
    }

    // one GEMM per (i12, i13) batch slice; src0 slices broadcast via r2/r3
    for (int64_t i13 = 0; i13 < ne13; i13++) {
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
            // NOTE(review): the wdata offset below assumes densely packed
            // batches of ne11 rows; supports_op enforces contiguous src1.
            if (!ggml_zendnn_sgemm(ctx,
                        ne01, // m
                        ne11, // n
                        ne10, // k
                        static_cast<const char *>(src0->data) + (i12/r2)*nb02 + (i13/r3)*nb03,
                        ne00, // lda
                        static_cast<const char *>(wdata) + (i12*ne11 + i13*ne12*ne11)*row_size,
                        ne10, // ldb
                        static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
                        ne01, // ldc
                        src0->type,
                        vec_dot_type,
                        dst->type))
                GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
        }
    }
}
192
193// backend interface
194
195static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
196 return "ZenDNN";
197
198 GGML_UNUSED(backend);
199}
200
201static void ggml_backend_zendnn_free(ggml_backend_t backend) {
202 ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
203 delete ctx;
204 delete backend;
205}
206
207static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
208 ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
209
210 for (int i = 0; i < cgraph->n_nodes; i++) {
211 struct ggml_tensor * node = cgraph->nodes[i];
212
213 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
214 continue;
215 }
216
217 switch (node->op) {
218 case GGML_OP_MUL_MAT:
219 ggml_zendnn_compute_forward_mul_mat(ctx, node);
220 break;
221 case GGML_OP_NONE:
222 case GGML_OP_RESHAPE:
223 case GGML_OP_VIEW:
224 case GGML_OP_PERMUTE:
225 case GGML_OP_TRANSPOSE:
226 break;
227
228 default:
229 GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
230 }
231 }
232
233 return GGML_STATUS_SUCCESS;
234
235 GGML_UNUSED(backend);
236}
237
// Backend vtable. Only name/free/graph_compute are implemented: this is a
// synchronous host backend, so async transfers, events and graph plans are
// all left NULL.
static struct ggml_backend_i ggml_backend_zendnn_i = {
    /* .get_name           = */ ggml_backend_zendnn_get_name,
    /* .free               = */ ggml_backend_zendnn_free,
    /* .set_tensor_async   = */ NULL,
    /* .get_tensor_async   = */ NULL,
    /* .cpy_tensor_async   = */ NULL,
    /* .synchronize        = */ NULL,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_zendnn_graph_compute,
    /* .event_record       = */ NULL,
    /* .event_wait         = */ NULL,
    /* .graph_optimize     = */ NULL,
};
254
255static ggml_guid_t ggml_backend_zendnn_guid(void) {
256 static const char * guid_str = "AMD-ZENDNN-ACCEL";
257 return reinterpret_cast<ggml_guid_t>(const_cast<char*>(guid_str));
258}
259
260ggml_backend_t ggml_backend_zendnn_init(void) {
261 ggml_backend_zendnn_context * ctx = new ggml_backend_zendnn_context;
262
263 ggml_backend_t backend = new ggml_backend {
264 /* .guid = */ ggml_backend_zendnn_guid(),
265 /* .iface = */ ggml_backend_zendnn_i,
266 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_zendnn_reg(), 0),
267 /* .context = */ ctx,
268 };
269
270 return backend;
271}
272
273bool ggml_backend_is_zendnn(ggml_backend_t backend) {
274 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zendnn_guid());
275}
276
277void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads) {
278 GGML_ASSERT(ggml_backend_is_zendnn(backend_zendnn));
279
280 ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend_zendnn->context;
281 ctx->n_threads = n_threads;
282}
283
284// device interface
285static const char * ggml_backend_zendnn_device_get_name(ggml_backend_dev_t dev) {
286 return "ZenDNN";
287
288 GGML_UNUSED(dev);
289}
290/**
291 * ZenDNN is AMD's performance library providing optimized primitives and implementations
292 * for deep learning workloads on AMD CPUs. It targets improved performance for common
293 * neural network operations on AMD architectures. For more information, see:
294 * https://www.amd.com/en/developer/zendnn.html
295 */
296static const char * ggml_backend_zendnn_device_get_description(ggml_backend_dev_t dev) {
297 return "ZenDNN: AMD optimized primitives backend for GGML (optimized for AMD CPUs)";
298
299 GGML_UNUSED(dev);
300}
301
302static void ggml_backend_zendnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
303 *free = 0;
304 *total = 0;
305
306 GGML_UNUSED(dev);
307}
308
309static enum ggml_backend_dev_type ggml_backend_zendnn_device_get_type(ggml_backend_dev_t dev) {
310 return GGML_BACKEND_DEVICE_TYPE_ACCEL;
311
312 GGML_UNUSED(dev);
313}
314
// Fill in the device property structure for this device.
static void ggml_backend_zendnn_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name = ggml_backend_zendnn_device_get_name(dev);
    props->description = ggml_backend_zendnn_device_get_description(dev);
    props->type = ggml_backend_zendnn_device_get_type(dev);
    ggml_backend_zendnn_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                = */ false, // synchronous compute only
        /* .host_buffer          = */ false, // no dedicated host buffer type
        /* .buffer_from_host_ptr = */ true,  // host pointers wrap directly
        /* .events               = */ false  // no event support
    };
}
327
328static ggml_backend_t ggml_backend_zendnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {
329 ggml_backend_t backend = ggml_backend_zendnn_init();
330 if (backend == NULL) {
331 GGML_LOG_ERROR("%s: error: failed to initialize ZenDNN backend\n", __func__);
332 return NULL;
333 }
334
335 return backend;
336
337 GGML_UNUSED(dev);
338 GGML_UNUSED(params);
339}
340
341static ggml_backend_buffer_type_t ggml_backend_zendnn_device_get_buffer_type(ggml_backend_dev_t dev) {
342 return ggml_backend_cpu_buffer_type();
343
344 GGML_UNUSED(dev);
345}
346
347static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
348 return ggml_backend_cpu_buffer_from_ptr(ptr, size);
349
350 GGML_UNUSED(dev);
351 GGML_UNUSED(max_tensor_size);
352}
353
354static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
355 switch (op->op) {
356 case GGML_OP_NONE:
357 case GGML_OP_RESHAPE:
358 case GGML_OP_VIEW:
359 case GGML_OP_PERMUTE:
360 case GGML_OP_TRANSPOSE:
361 return true;
362
363 case GGML_OP_MUL_MAT:
364 {
365 const ggml_tensor * weights = op->src[0];
366 const ggml_tensor * inputs = op->src[1];
367
368 const int64_t ne10 = inputs->ne[0];
369 const int64_t ne0 = op->ne[0];
370 const int64_t ne1 = op->ne[1];
371
372 const int64_t min_batch = 1;
373 if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
374 ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
375 return false;
376 }
377 switch (weights->type) {
378 case GGML_TYPE_F32:
379 case GGML_TYPE_BF16:
380 return true;
381 default:
382 return false;
383 }
384 } break;
385
386 default:
387 return false;
388 }
389
390 GGML_UNUSED(dev);
391}
392
393static bool ggml_backend_zendnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
394 return ggml_backend_buft_is_host(buft);
395
396 GGML_UNUSED(dev);
397}
398
// Device vtable. Host-buffer type, offload decisions and events are
// unimplemented (NULL); buffers come from the CPU buffer type.
static const struct ggml_backend_device_i ggml_backend_zendnn_device_i = {
    /* .get_name             = */ ggml_backend_zendnn_device_get_name,
    /* .get_description      = */ ggml_backend_zendnn_device_get_description,
    /* .get_memory           = */ ggml_backend_zendnn_device_get_memory,
    /* .get_type             = */ ggml_backend_zendnn_device_get_type,
    /* .get_props            = */ ggml_backend_zendnn_device_get_props,
    /* .init_backend         = */ ggml_backend_zendnn_device_init_backend,
    /* .get_buffer_type      = */ ggml_backend_zendnn_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_zendnn_device_buffer_from_host_ptr,
    /* .supports_op          = */ ggml_backend_zendnn_device_supports_op,
    /* .supports_buft        = */ ggml_backend_zendnn_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
416
417// backend reg interface
418static const char * ggml_backend_zendnn_reg_get_name(ggml_backend_reg_t reg) {
419 return "ZenDNN";
420
421 GGML_UNUSED(reg);
422}
423
424static size_t ggml_backend_zendnn_reg_get_device_count(ggml_backend_reg_t reg) {
425 return 1;
426
427 GGML_UNUSED(reg);
428}
429
430static ggml_backend_dev_t ggml_backend_zendnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {
431 GGML_ASSERT(index == 0);
432
433 static ggml_backend_device ggml_backend_zendnn_device = {
434 /* .iface = */ ggml_backend_zendnn_device_i,
435 /* .reg = */ reg,
436 /* .context = */ nullptr,
437 };
438
439 return &ggml_backend_zendnn_device;
440}
441
442static void * ggml_backend_zendnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {
443 if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
444 return (void *) ggml_backend_zendnn_set_n_threads;
445 }
446 return NULL;
447
448 GGML_UNUSED(reg);
449 GGML_UNUSED(name);
450}
451
// Registry vtable: name, device enumeration and proc-address lookup.
static const struct ggml_backend_reg_i ggml_backend_zendnn_reg_i = {
    /* .get_name         = */ ggml_backend_zendnn_reg_get_name,
    /* .get_device_count = */ ggml_backend_zendnn_reg_get_device_count,
    /* .get_device       = */ ggml_backend_zendnn_reg_get_device,
    /* .get_proc_address = */ ggml_backend_zendnn_get_proc_address,
};
458
459ggml_backend_reg_t ggml_backend_zendnn_reg(void) {
460 static struct ggml_backend_reg ggml_backend_zendnn_reg = {
461 /* .api_version = */ GGML_BACKEND_API_VERSION,
462 /* .iface = */ ggml_backend_zendnn_reg_i,
463 /* .context = */ NULL,
464 };
465
466 return &ggml_backend_zendnn_reg;
467}
468
469GGML_BACKEND_DL_IMPL(ggml_backend_zendnn_reg)