1#include "amx.h"
  2#include "common.h"
  3#include "mmq.h"
  4#include "ggml-backend-impl.h"
  5#include "ggml-backend.h"
  6#include "ggml-impl.h"
  7#include "ggml-cpu.h"
  8#include "traits.h"
  9
 10#if defined(__linux__)
 11#include <sys/syscall.h>
 12#include <unistd.h>
 13#endif
 14
 15#include <cstdlib>
 16#include <cstring>
 17#include <memory>
 18
 19#if defined(__AMX_INT8__) && defined(__AVX512VNNI__)
 20
 21// AMX type_trais
 22namespace ggml::cpu::amx {
 23class tensor_traits : public ggml::cpu::tensor_traits {
 24    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
 25        size = ggml_backend_amx_desired_wsize(op);
 26        return true;
 27    }
 28
 29    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
 30        if (op->op == GGML_OP_MUL_MAT) {
 31            ggml_backend_amx_mul_mat(params, op);
 32            return true;
 33        }
 34        return false;
 35    }
 36};
 37
 38static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struct ggml_tensor *) {
 39    static tensor_traits traits;
 40    return &traits;
 41}
 42}  // namespace ggml::cpu::amx
 43
 44// AMX buffer interface
 45static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 46    free(buffer->context);
 47}
 48
 49static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
 50    return (void *) (buffer->context);
 51}
 52
 53static enum ggml_status ggml_backend_amx_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
 54    tensor->extra = (void *) ggml::cpu::amx::get_tensor_traits(buffer, tensor);
 55
 56    GGML_UNUSED(buffer);
 57    return GGML_STATUS_SUCCESS;
 58}
 59
 60static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
 61                                                  uint8_t value, size_t offset, size_t size) {
 62    memset((char *) tensor->data + offset, value, size);
 63
 64    GGML_UNUSED(buffer);
 65}
 66
 67static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
 68                                               const void * data, size_t offset, size_t size) {
 69    if (qtype_has_amx_kernels(tensor->type)) {
 70        GGML_LOG_DEBUG("%s: amx repack tensor %s of type %s\n", __func__, tensor->name, ggml_type_name(tensor->type));
 71        ggml_backend_amx_convert_weight(tensor, data, offset, size);
 72    } else {
 73        memcpy((char *) tensor->data + offset, data, size);
 74    }
 75
 76    GGML_UNUSED(buffer);
 77}
 78
 79/*
// TODO: figure out what needs to be done with buffer->extra before enabling these.
 81static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
 82    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
 83    memcpy(data, (const char *)tensor->data + offset, size);
 84
 85    GGML_UNUSED(buffer);
 86}
 87
 88static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
 89    if (ggml_backend_buffer_is_host(src->buffer)) {
 90        if (qtype_has_amx_kernels(src->type)) {
 91            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_nbytes(dst));
 92        } else {
 93            memcpy(dst->data, src->data, ggml_nbytes(src));
 94        }
 95        return true;
 96    }
 97    return false;
 98
 99    GGML_UNUSED(buffer);
100}
101*/
102
103static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
104    memset(buffer->context, value, buffer->size);
105}
106
// Callback table shared by all AMX buffers (passed to ggml_backend_buffer_init()).
// get_tensor and cpy_tensor are deliberately left null; their draft implementations are
// kept commented out above. NOTE(review): presumably this is because AMX-kernel weight
// types are repacked on upload, so their stored bytes cannot be read back as-is.
// TODO confirm.
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_amx_buffer_init_tensor,
    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
    /* .get_tensor      = */ nullptr,
    /* .cpy_tensor      = */ nullptr,
    /* .clear           = */ ggml_backend_amx_buffer_clear,
    /* .reset           = */ nullptr,
};
118
119static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
120    return "AMX";
121
122    GGML_UNUSED(buft);
123}
124
125static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
126    void * data = ggml_aligned_malloc(size);
127    if (data == NULL) {
128        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
129        return NULL;
130    }
131
132    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
133}
134
135static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
136    return TENSOR_ALIGNMENT;
137
138    GGML_UNUSED(buft);
139}
140
141namespace ggml::cpu::amx {
142class extra_buffer_type : ggml::cpu::extra_buffer_type {
143    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
144        // handle only 2d gemm for now
145        auto is_contiguous_2d = [](const struct ggml_tensor * t) {
146            return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
147        };
148
149        if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) &&  // src0 must be contiguous
150            is_contiguous_2d(op->src[1]) &&                               // src1 must be contiguous
151            op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() &&
152            op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315)
153            op->ne[0] % (TILE_N * 2) == 0 &&                              // out_features is 32x
154            (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) {
155            // src1 must be host buffer
156            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
157                return false;
158            }
159            // src1 must be float32
160            if (op->src[1]->type == GGML_TYPE_F32) {
161                return true;
162            }
163        }
164        return false;
165    }
166
167    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
168        if (op->op == GGML_OP_MUL_MAT && op->src[0]->buffer &&
169            op->src[0]->buffer->buft == ggml_backend_amx_buffer_type()) {
170            return (ggml::cpu::tensor_traits *) op->src[0]->extra;
171        }
172
173        return nullptr;
174    }
175};
176}  // namespace ggml::cpu::amx
177
178static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
179    return ggml_backend_amx_get_alloc_size(tensor);
180
181    GGML_UNUSED(buft);
182}
183
184#define ARCH_GET_XCOMP_PERM     0x1022
185#define ARCH_REQ_XCOMP_PERM     0x1023
186#define XFEATURE_XTILECFG       17
187#define XFEATURE_XTILEDATA      18
188
189static bool ggml_amx_init() {
190#if defined(__linux__)
191    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
192        fprintf(stderr, "AMX is not ready to be used!\n");
193        return false;
194    }
195    return true;
196#elif defined(_WIN32)
197    return true;
198#else
199    return false;
200#endif
201}
202
203ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
204    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
205        /* .iface = */ {
206                        /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
207                        /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
208                        /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
209                        /* .get_max_size     = */ nullptr,  // defaults to SIZE_MAX
210                        /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
211                        /* .is_host          = */ nullptr,
212                        },
213        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
214        /* .context = */ new ggml::cpu::amx::extra_buffer_type(),
215    };
216
217    if (!ggml_amx_init()) {
218        return nullptr;
219    }
220
221    return &ggml_backend_buffer_type_amx;
222}
223
224#endif  // defined(__AMX_INT8__) && defined(__AVX512VNNI__)