#include "ggml-zendnn.h"

#include "ggml-backend-impl.h"
#include "ggml-impl.h"
#include "zendnnl.hpp"

#include <cassert>
#include <cstring>
#include <memory>
#include <type_traits>
  9
 10struct ggml_backend_zendnn_context {
 11    int n_threads = GGML_DEFAULT_N_THREADS;
 12    std::unique_ptr<char[]> work_data;
 13    size_t work_size = 0;
 14};
 15
 16template<typename T>
 17zendnnl::common::data_type_t ggml_to_zendnn_type() {
 18    if constexpr (std::is_same_v<T, float>) {
 19        return zendnnl::common::data_type_t::f32;
 20    } else if constexpr (std::is_same_v<T, ggml_bf16_t>) {
 21        return zendnnl::common::data_type_t::bf16;
 22    } else {
 23        return zendnnl::common::data_type_t::none;
 24    }
 25}
 26
 27/**
 28 * ZenDNN matmul: computes C = B * A.
 29 *
 30 * - A: weights, shape (k, m), column-major (each column is a weight vector for one output).
 31 * - B: input, shape (n, k), row-major (each row is an input sample).
 32 * - C: output, shape (n, m), row-major.
 33 *
 34 * Dimensions:
 35 *   m = output features (columns of C, columns of A)
 36 *   n = batch size      (rows of C, rows of B)
 37 *   k = inner dimension (columns of B, rows of A)
 38 */
 39template <typename TA, typename TB, typename TC>
 40static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
 41                               const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C,
 42                               int64_t ldc) {
 43
 44    zendnnl::lowoha::lowoha_params params;
 45    params.dtypes.src = ggml_to_zendnn_type<TB>();
 46    params.dtypes.wei = ggml_to_zendnn_type<TA>();
 47    params.dtypes.dst = ggml_to_zendnn_type<TC>();
 48    params.num_threads = ctx->n_threads;
 49
 50    zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct(
 51        'r', false, true,   // row-major, don't transpose B, transpose A (because it's column-major)
 52        n,                  // M: rows of B and C
 53        m,                  // N: cols of A^T and C
 54        k,                  // K: cols of B, rows of A
 55        1.0f,               // alpha
 56        B, ldb,             // src: B[n,k]
 57        A, lda,             // weight: A[k,m] column-major (transposed)
 58        nullptr,            // bias
 59        0.0f,               // beta
 60        C, ldc,             // output C[n,m]
 61        true,               // is_weights_const
 62        {},                 // batch_params
 63        params              // params
 64    );
 65
 66    if (status != zendnnl::lowoha::status_t::success) {
 67        GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status));
 68        return false;
 69    }
 70    return true;
 71}
 72
 73static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k,
 74                              const void * A, int64_t lda, const void * B, int64_t ldb, void * C,
 75                              int64_t ldc, int Atype, int Btype, int Ctype) {
 76
 77    assert(m >= 0);
 78    assert(n >= 0);
 79    assert(k >= 0);
 80    assert(lda >= k);
 81    assert(ldb >= k);
 82    assert(ldc >= m);
 83
 84    // categorize types
 85    switch (Atype) {
 86        case GGML_TYPE_F32:
 87            if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32)
 88                return false;
 89            return ggml_zendnn_matmul<float, float, float>(
 90                ctx, m, n, k,
 91                (const float *)A, lda,
 92                (const float *)B, ldb,
 93                (float *)C, ldc);
 94        case GGML_TYPE_BF16:
 95            if (Btype != GGML_TYPE_BF16)
 96                return false;
 97            if (Ctype == GGML_TYPE_BF16)
 98                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, ggml_bf16_t>(
 99                    ctx, m, n, k,
100                    (const ggml_bf16_t *)A, lda,
101                    (const ggml_bf16_t *)B, ldb,
102                    (ggml_bf16_t *)C, ldc);
103            if (Ctype == GGML_TYPE_F32)
104                return ggml_zendnn_matmul<ggml_bf16_t, ggml_bf16_t, float>(
105                    ctx, m, n, k,
106                    (const ggml_bf16_t *)A, lda,
107                    (const ggml_bf16_t *)B, ldb,
108                    (float *)C, ldc);
109            return false;
110        default:
111            return false; // unsupported type
112    }
113}
114
115static void ggml_zendnn_compute_forward_mul_mat(
116    ggml_backend_zendnn_context * ctx,
117    ggml_tensor * dst) {
118
119    const ggml_tensor * src0 = dst->src[0];  // weights
120    const ggml_tensor * src1 = dst->src[1];  // inputs
121
122    GGML_TENSOR_BINARY_OP_LOCALS
123
124    ggml_type         const vec_dot_type = src0->type;
125    ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref;
126
127    GGML_ASSERT(ne0 == ne01);
128    GGML_ASSERT(ne1 == ne11);
129    GGML_ASSERT(ne2 == ne12);
130    GGML_ASSERT(ne3 == ne13);
131
132    // we don't support permuted src0 or src1
133    GGML_ASSERT(nb00 == ggml_type_size(src0->type));
134    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
135
136    // dst cannot be transposed or permuted
137    GGML_ASSERT(nb0 == sizeof(float));
138    GGML_ASSERT(nb0 <= nb1);
139    GGML_ASSERT(nb1 <= nb2);
140    GGML_ASSERT(nb2 <= nb3);
141
142    // broadcast factors
143    const int64_t r2 = ne12/ne02;
144    const int64_t r3 = ne13/ne03;
145
146    void * work_data = ctx->work_data.get();
147    if (src1->type != vec_dot_type) {
148        const size_t nbw1 = ggml_row_size(vec_dot_type, ne10);
149        const size_t nbw2 = nbw1 * ne11;
150        const size_t nbw3 = nbw2 * ne12;
151        const size_t desired_wsize = ne13 * nbw3;
152        if (ctx->work_size < desired_wsize) {
153            ctx->work_data.reset(new char[desired_wsize]);
154            ctx->work_size = desired_wsize;
155        }
156        work_data = ctx->work_data.get();
157
158        // #pragma omp parallel for num_threads(ctx->n_threads)
159        #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static)
160        for (int64_t i13 = 0; i13 < ne13; ++i13) {
161            for (int64_t i12 = 0; i12 < ne12; ++i12) {
162                for (int64_t i11 = 0; i11 < ne11; ++i11) {
163                    const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13);
164                    void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3;
165                    from_float(src1_f32, src1_conv, ne10);
166                }
167            }
168        }
169    }
170
171    for (int64_t i13 = 0; i13 < ne13; i13++) {
172        for (int64_t i12 = 0; i12 < ne12; i12++) {
173            const void* wdata = src1->type == vec_dot_type ? src1->data : work_data;
174            const size_t row_size = ggml_row_size(vec_dot_type, ne10);
175            if (!ggml_zendnn_sgemm(ctx,
176                                  ne01,     // m
177                                  ne11,     // n
178                                  ne10,     // k
179                                  static_cast<const char *>(src0->data) + (i12/r2)*nb02 + (i13/r3)*nb03,
180                                  ne00,     // lda
181                                  static_cast<const char *>(wdata) + (i12*ne11 + i13*ne12*ne11)*row_size,
182                                  ne10,     // ldb
183                                  static_cast<char *>(dst->data) + i12*nb2 + i13*nb3,
184                                  ne01,     // ldc
185                                  src0->type,
186                                  vec_dot_type,
187                                  dst->type))
188                GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__);
189        }
190    }
191}
192
193// backend interface
194
195static const char * ggml_backend_zendnn_get_name(ggml_backend_t backend) {
196    return "ZenDNN";
197
198    GGML_UNUSED(backend);
199}
200
201static void ggml_backend_zendnn_free(ggml_backend_t backend) {
202    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
203    delete ctx;
204    delete backend;
205}
206
207static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
208    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend->context;
209
210    for (int i = 0; i < cgraph->n_nodes; i++) {
211        struct ggml_tensor * node = cgraph->nodes[i];
212
213        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
214            continue;
215        }
216
217        switch (node->op) {
218            case GGML_OP_MUL_MAT:
219                ggml_zendnn_compute_forward_mul_mat(ctx, node);
220                break;
221            case GGML_OP_NONE:
222            case GGML_OP_RESHAPE:
223            case GGML_OP_VIEW:
224            case GGML_OP_PERMUTE:
225            case GGML_OP_TRANSPOSE:
226                break;
227
228            default:
229                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
230        }
231    }
232
233    return GGML_STATUS_SUCCESS;
234
235    GGML_UNUSED(backend);
236}
237
238static struct ggml_backend_i ggml_backend_zendnn_i = {
239    /* .get_name                = */ ggml_backend_zendnn_get_name,
240    /* .free                    = */ ggml_backend_zendnn_free,
241    /* .set_tensor_async        = */ NULL,
242    /* .get_tensor_async        = */ NULL,
243    /* .cpy_tensor_async        = */ NULL,
244    /* .synchronize             = */ NULL,
245    /* .graph_plan_create       = */ NULL,
246    /* .graph_plan_free         = */ NULL,
247    /* .graph_plan_update       = */ NULL,
248    /* .graph_plan_compute      = */ NULL,
249    /* .graph_compute           = */ ggml_backend_zendnn_graph_compute,
250    /* .event_record            = */ NULL,
251    /* .event_wait              = */ NULL,
252    /* .graph_optimize          = */ NULL,
253};
254
255static ggml_guid_t ggml_backend_zendnn_guid(void) {
256    static const char * guid_str = "AMD-ZENDNN-ACCEL";
257    return reinterpret_cast<ggml_guid_t>(const_cast<char*>(guid_str));
258}
259
260ggml_backend_t ggml_backend_zendnn_init(void) {
261    ggml_backend_zendnn_context * ctx = new ggml_backend_zendnn_context;
262
263    ggml_backend_t backend = new ggml_backend {
264        /* .guid    = */ ggml_backend_zendnn_guid(),
265        /* .iface   = */ ggml_backend_zendnn_i,
266        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_zendnn_reg(), 0),
267        /* .context = */ ctx,
268    };
269
270    return backend;
271}
272
273bool ggml_backend_is_zendnn(ggml_backend_t backend) {
274    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_zendnn_guid());
275}
276
277void ggml_backend_zendnn_set_n_threads(ggml_backend_t backend_zendnn, int n_threads) {
278    GGML_ASSERT(ggml_backend_is_zendnn(backend_zendnn));
279
280    ggml_backend_zendnn_context * ctx = (ggml_backend_zendnn_context *)backend_zendnn->context;
281    ctx->n_threads = n_threads;
282}
283
284// device interface
285static const char * ggml_backend_zendnn_device_get_name(ggml_backend_dev_t dev) {
286    return "ZenDNN";
287
288    GGML_UNUSED(dev);
289}
290/**
291 * ZenDNN is AMD's performance library providing optimized primitives and implementations
292 * for deep learning workloads on AMD CPUs. It targets improved performance for common
293 * neural network operations on AMD architectures. For more information, see:
294 * https://www.amd.com/en/developer/zendnn.html
295 */
296static const char * ggml_backend_zendnn_device_get_description(ggml_backend_dev_t dev) {
297    return "ZenDNN: AMD optimized primitives backend for GGML (optimized for AMD CPUs)";
298
299    GGML_UNUSED(dev);
300}
301
302static void ggml_backend_zendnn_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
303    *free  = 0;
304    *total = 0;
305
306    GGML_UNUSED(dev);
307}
308
309static enum ggml_backend_dev_type ggml_backend_zendnn_device_get_type(ggml_backend_dev_t dev) {
310    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
311
312    GGML_UNUSED(dev);
313}
314
315static void ggml_backend_zendnn_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
316    props->name        = ggml_backend_zendnn_device_get_name(dev);
317    props->description = ggml_backend_zendnn_device_get_description(dev);
318    props->type        = ggml_backend_zendnn_device_get_type(dev);
319    ggml_backend_zendnn_device_get_memory(dev, &props->memory_free, &props->memory_total);
320    props->caps = {
321        /* .async                = */ false,
322        /* .host_buffer          = */ false,
323        /* .buffer_from_host_ptr = */ true,
324        /* .events               = */ false
325    };
326}
327
328static ggml_backend_t ggml_backend_zendnn_device_init_backend(ggml_backend_dev_t dev, const char * params) {
329    ggml_backend_t backend = ggml_backend_zendnn_init();
330    if (backend == NULL) {
331        GGML_LOG_ERROR("%s: error: failed to initialize ZenDNN backend\n", __func__);
332        return NULL;
333    }
334
335    return backend;
336
337    GGML_UNUSED(dev);
338    GGML_UNUSED(params);
339}
340
341static ggml_backend_buffer_type_t ggml_backend_zendnn_device_get_buffer_type(ggml_backend_dev_t dev) {
342    return ggml_backend_cpu_buffer_type();
343
344    GGML_UNUSED(dev);
345}
346
347static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
348    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
349
350    GGML_UNUSED(dev);
351    GGML_UNUSED(max_tensor_size);
352}
353
354static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
355    switch (op->op) {
356        case GGML_OP_NONE:
357        case GGML_OP_RESHAPE:
358        case GGML_OP_VIEW:
359        case GGML_OP_PERMUTE:
360        case GGML_OP_TRANSPOSE:
361            return true;
362
363        case GGML_OP_MUL_MAT:
364        {
365            const ggml_tensor * weights = op->src[0];
366            const ggml_tensor * inputs = op->src[1];
367
368            const int64_t ne10 = inputs->ne[0];
369            const int64_t ne0 = op->ne[0];
370            const int64_t ne1 = op->ne[1];
371
372            const int64_t min_batch = 1;
373            if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) ||
374                ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) {
375                    return false;
376            }
377            switch (weights->type) {
378                case GGML_TYPE_F32:
379                case GGML_TYPE_BF16:
380                    return true;
381                default:
382                    return false;
383            }
384        } break;
385
386        default:
387            return false;
388    }
389
390    GGML_UNUSED(dev);
391}
392
393static bool ggml_backend_zendnn_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
394    return ggml_backend_buft_is_host(buft);
395
396    GGML_UNUSED(dev);
397}
398
399static const struct ggml_backend_device_i ggml_backend_zendnn_device_i = {
400    /* .get_name               = */ ggml_backend_zendnn_device_get_name,
401    /* .get_description        = */ ggml_backend_zendnn_device_get_description,
402    /* .get_memory             = */ ggml_backend_zendnn_device_get_memory,
403    /* .get_type               = */ ggml_backend_zendnn_device_get_type,
404    /* .get_props              = */ ggml_backend_zendnn_device_get_props,
405    /* .init_backend           = */ ggml_backend_zendnn_device_init_backend,
406    /* .get_buffer_type        = */ ggml_backend_zendnn_device_get_buffer_type,
407    /* .get_host_buffer_type   = */ NULL,
408    /* .buffer_from_host_ptr   = */ ggml_backend_zendnn_device_buffer_from_host_ptr,
409    /* .supports_op            = */ ggml_backend_zendnn_device_supports_op,
410    /* .supports_buft          = */ ggml_backend_zendnn_device_supports_buft,
411    /* .offload_op             = */ NULL,
412    /* .event_new              = */ NULL,
413    /* .event_free             = */ NULL,
414    /* .event_synchronize      = */ NULL,
415};
416
417// backend reg interface
418static const char * ggml_backend_zendnn_reg_get_name(ggml_backend_reg_t reg) {
419    return "ZenDNN";
420
421    GGML_UNUSED(reg);
422}
423
424static size_t ggml_backend_zendnn_reg_get_device_count(ggml_backend_reg_t reg) {
425    return 1;
426
427    GGML_UNUSED(reg);
428}
429
430static ggml_backend_dev_t ggml_backend_zendnn_reg_get_device(ggml_backend_reg_t reg, size_t index) {
431    GGML_ASSERT(index == 0);
432
433    static ggml_backend_device ggml_backend_zendnn_device = {
434        /* .iface   = */ ggml_backend_zendnn_device_i,
435        /* .reg     = */ reg,
436        /* .context = */ nullptr,
437    };
438
439    return &ggml_backend_zendnn_device;
440}
441
442static void * ggml_backend_zendnn_get_proc_address(ggml_backend_reg_t reg, const char * name) {
443    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
444        return (void *) ggml_backend_zendnn_set_n_threads;
445    }
446    return NULL;
447
448    GGML_UNUSED(reg);
449    GGML_UNUSED(name);
450}
451
452static const struct ggml_backend_reg_i ggml_backend_zendnn_reg_i = {
453    /* .get_name         = */ ggml_backend_zendnn_reg_get_name,
454    /* .get_device_count = */ ggml_backend_zendnn_reg_get_device_count,
455    /* .get_device       = */ ggml_backend_zendnn_reg_get_device,
456    /* .get_proc_address = */ ggml_backend_zendnn_get_proc_address,
457};
458
459ggml_backend_reg_t ggml_backend_zendnn_reg(void) {
460    static struct ggml_backend_reg ggml_backend_zendnn_reg = {
461        /* .api_version = */ GGML_BACKEND_API_VERSION,
462        /* .iface       = */ ggml_backend_zendnn_reg_i,
463        /* .context     = */ NULL,
464    };
465
466    return &ggml_backend_zendnn_reg;
467}
468
// Registers ggml_backend_zendnn_reg with ggml's backend dynamic-loading macro.
// NOTE(review): the exact expansion (exported symbols, when it is active) is defined by
// GGML_BACKEND_DL_IMPL elsewhere — confirm there.
GGML_BACKEND_DL_IMPL(ggml_backend_zendnn_reg)