1#pragma once
  2
  3#include "ggml.h"
  4
  5#ifdef __cplusplus
  6extern "C" {
  7#endif
  8
  9struct ggml_metal_buffer_id {
 10    void * metal; // id<MTLBuffer>
 11    size_t offs;
 12};
 13
 14typedef struct ggml_metal_device * ggml_metal_device_t;
 15
 16//
 17// MTLFunctionConstantValues wrapper
 18//
 19
 20typedef struct ggml_metal_cv * ggml_metal_cv_t;
 21
 22ggml_metal_cv_t ggml_metal_cv_init(void);
 23void ggml_metal_cv_free(ggml_metal_cv_t cv);
 24
 25void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
 26void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
 27void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool    value, int32_t idx);
 28
 29//
 30// MTLComputePipelineState wrapper
 31//
 32
 33typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
 34
 35ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
 36void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
 37
 38// a collection of pipelines
 39typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
 40
 41ggml_metal_pipelines_t ggml_metal_pipelines_init(void);
 42void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
 43
 44void                  ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
 45ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
 46
 47struct ggml_metal_pipeline_with_params {
 48    ggml_metal_pipeline_t pipeline;
 49
 50    int nsg;
 51
 52    int nr0;
 53    int nr1;
 54
 55    size_t smem;
 56
 57    bool c4;
 58    bool cnt;
 59};
 60
 61int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
 62
 63//
 64// MTLCommandBuffer wrapper
 65//
 66
 67typedef void * ggml_metal_cmd_buf_t;
 68
 69//
 70// MTLComputeCommandEncoder wrapper
 71//
 72
 73typedef struct ggml_metal_encoder * ggml_metal_encoder_t;
 74
 75ggml_metal_encoder_t ggml_metal_encoder_init(ggml_metal_cmd_buf_t cmd_buf_raw, bool concurrent);
 76void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
 77
 78void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
 79void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
 80
 81void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
 82
 83void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
 84void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
 85
 86void ggml_metal_encoder_set_threadgroup_memory_size(ggml_metal_encoder_t encoder, size_t size, int idx);
 87
 88void ggml_metal_encoder_dispatch_threadgroups(ggml_metal_encoder_t encoder, int tg0, int tg1, int tg2, int tptg0, int tptg1, int tptg2);
 89
 90void ggml_metal_encoder_memory_barrier(ggml_metal_encoder_t encoder);
 91
 92void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder);
 93
 94//
 95// MTLLibrary wrapper
 96//
 97
 98typedef struct ggml_metal_library * ggml_metal_library_t;
 99
100ggml_metal_library_t ggml_metal_library_init            (ggml_metal_device_t dev);
101ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev, const char * source, bool verbose);
102
103void ggml_metal_library_free(ggml_metal_library_t lib);
104
105struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline    (ggml_metal_library_t lib, const char * name);
106struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
107
108struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
109struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
110struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
111struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
112struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
113struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
114struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag              (ggml_metal_library_t lib, const struct ggml_tensor * op);
115struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
116struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
117struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
118struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);
119struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);
120struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk        (ggml_metal_library_t lib, const struct ggml_tensor * op);
121struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add        (ggml_metal_library_t lib, const struct ggml_tensor * op);
122struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_tri               (ggml_metal_library_t lib, const struct ggml_tensor * op);
123struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
124struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
125struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched  (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
126struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
127struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
128struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri         (ggml_metal_library_t lib, const struct ggml_tensor * op);
129struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
130struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
131struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
132struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0    (ggml_metal_library_t lib, int ne02, int ne20);
133struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
134struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
135struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
136struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
137struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
138struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
139struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
140struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
141struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one           (ggml_metal_library_t lib, enum ggml_op op);
142struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
143struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
144struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
145struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
146struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
147struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
148struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
149struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
150struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
151struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
152struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
153struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
154struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
155struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
156struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
157struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset            (ggml_metal_library_t lib, const struct ggml_tensor * op);
158struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal       (ggml_metal_library_t lib, const struct ggml_tensor * op);
159
160struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
161        ggml_metal_library_t lib,
162        const struct ggml_tensor * op,
163        bool    has_mask,
164        int32_t ncpsg);
165
166struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
167        ggml_metal_library_t lib,
168        const struct ggml_tensor * op,
169        int32_t nqptg,
170        int32_t ncpsg);
171
172struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
173        ggml_metal_library_t lib,
174        const struct ggml_tensor * op,
175        bool    has_mask,
176        bool    has_sinks,
177        bool    has_bias,
178        bool    has_scap,
179        bool    has_kvpad,
180        int32_t nsg);
181
182struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
183        ggml_metal_library_t lib,
184        const struct ggml_tensor * op,
185        bool    has_mask,
186        bool    has_sinks,
187        bool    has_bias,
188        bool    has_scap,
189        bool    has_kvpad,
190        int32_t nsg,
191        int32_t nwg);
192
193struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
194        ggml_metal_library_t lib,
195        const struct ggml_tensor * op,
196        int32_t dv,
197        int32_t nwg);
198
//
// MTLResidencySet wrapper
//

// Raw residency set (id<MTLResidencySet>), passed around as void * so this header stays plain C.
typedef void * ggml_metal_rset_t;

// Opaque collection of residency sets. It is non-owning: it does not take
// ownership of the sets it refers to.
typedef struct ggml_metal_rsets * ggml_metal_rsets_t;

ggml_metal_rsets_t ggml_metal_rsets_init(void);
void               ggml_metal_rsets_free(ggml_metal_rsets_t rsets);
208
209//
210// device
211//
212
// Static properties of a Metal device, as returned by ggml_metal_device_get_props().
// NOTE(review): per-field notes are inferred from the names — confirm against the code that fills them.
struct ggml_metal_device_props {
    int  device;    // device index this record describes
    char name[128]; // device name (fixed-size buffer)
    char desc[128]; // device description (fixed-size buffer)

    // size limits
    size_t max_buffer_size;
    size_t max_working_set_size;
    size_t max_theadgroup_memory_size; // "thead" spelling is part of the public field name; kept for compatibility

    // capability / configuration flags
    bool has_simdgroup_reduction;
    bool has_simdgroup_mm;
    bool has_unified_memory;
    bool has_bfloat;
    bool has_tensor;
    bool use_residency_sets;
    bool use_shared_buffers;

    bool supports_gpu_family_apple7;

    int op_offload_min_batch_size; // presumably the batch size at or above which ops are offloaded — confirm
};
234
235typedef struct ggml_metal_event * ggml_metal_event_t;
236
237void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
238void ggml_metal_event_encode_wait  (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf);
239
240ggml_metal_device_t ggml_metal_device_init(int device);
241void ggml_metal_device_free(ggml_metal_device_t dev);
242
243ggml_metal_device_t ggml_metal_device_get(int device);
244
245void * ggml_metal_device_get_obj  (ggml_metal_device_t dev); // id<MTLDevice>
246void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue>
247
248ggml_metal_library_t ggml_metal_device_get_library(ggml_metal_device_t dev);
249
250void ggml_metal_device_rsets_add(ggml_metal_device_t dev, ggml_metal_rset_t rset);
251void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset);
252
253void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev);
254
255ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev);
256void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev);
257void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev);
258
259void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total);
260bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op);
261
262const struct ggml_metal_device_props * ggml_metal_device_get_props(ggml_metal_device_t dev);
263
264//
265// device buffers
266//
267
268typedef struct ggml_metal_buffer * ggml_metal_buffer_t;
269
270ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, bool shared);
271ggml_metal_buffer_t ggml_metal_buffer_map (ggml_metal_device_t dev, void * ptr, size_t size, size_t max_tensor_size);
272
273void   ggml_metal_buffer_free     (ggml_metal_buffer_t buf);
274void * ggml_metal_buffer_get_base (ggml_metal_buffer_t buf);
275bool   ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf);
276
277void   ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size);
278void   ggml_metal_buffer_set_tensor   (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
279void   ggml_metal_buffer_get_tensor   (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
280void   ggml_metal_buffer_clear        (ggml_metal_buffer_t buf, uint8_t value);
281
282// finds the Metal buffer that contains the tensor data on the GPU device
283// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
284// Metal buffer based on the host memory pointer
285//
286struct ggml_metal_buffer_id ggml_metal_buffer_get_id(ggml_metal_buffer_t buf, const struct ggml_tensor * t);
287
288#ifdef __cplusplus
289}
290#endif