1#pragma once
2
3#include "ggml-metal-device.h"
4
5#ifdef __cplusplus
6extern "C" {
7#endif
8
9typedef struct ggml_metal_op * ggml_metal_op_t;
10
11ggml_metal_op_t ggml_metal_op_init(
12 ggml_metal_device_t dev,
13 ggml_metal_cmd_buf_t cmd_buf,
14 struct ggml_cgraph * gf,
15 int idx_start,
16 int idx_end,
17 bool use_fusion,
18 bool use_concurrency,
19 bool use_capture,
20 int debug_graph,
21 int debug_fusion);
22
23void ggml_metal_op_free(ggml_metal_op_t ctx);
24
25int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
26
27int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
28
29//
30// available ops:
31//
32
33// tokens per expert
34size_t ggml_metal_op_mul_mat_id_extra_tpe(const struct ggml_tensor * op);
35
36// id map [n_tokens, n_expert]
37size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op);
38
39// return true if we should use the FA vector kernel for this op
40bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op);
41
42size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op);
43size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op);
44size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
45
46int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
47int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
48int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
49int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
50int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
51int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
52int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
53int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx);
54int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
55int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
56int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx);
57int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
58int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
59int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
60int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
61int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
62int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
63int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
64int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
65int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx);
66int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx);
67int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
68int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
69int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
70int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
71int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
72int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
73int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
74int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx);
75int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx);
76int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
77int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
78int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx);
79int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx);
80int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
81int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
82int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
83int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
84int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
85int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
86int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
87int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
88int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
89int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
90
91#ifdef __cplusplus
92}
93#endif