1#ifndef HTP_OPS_H
 2#define HTP_OPS_H
 3
 4#include "htp-ctx.h"
 5#include "htp-msg.h"
 6#include "worker-pool.h"
 7
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
10
11#include <hex-fastdiv.h>
12
13// ggml-common.h must be included prior to this header
14
// Descriptor for a scratch buffer attached to an op (one per src0..src3 and
// dst in struct htp_ops_context). Only the raw pointer and sizes live here;
// how the buffer is allocated and filled is outside this header.
struct htp_spad {
    uint8_t * data;             // base address of the scratch buffer
    size_t    stride;           // stride value for this buffer (units/meaning not visible here — confirm)
    size_t    size;             // total buffer size (presumably bytes — confirm)
    size_t    size_per_thread;  // share of the buffer per worker thread; suggests the buffer is split across threads — confirm
};
21
// Per-operation context passed to every op_* entry point declared in this header.
// It bundles together:
//   - the op id and its parameters,
//   - up to five source tensors and the destination tensor,
//   - scratch buffers for src0..src3 and dst,
//   - the worker pool and thread count,
//   - precomputed fastdiv constants (hex-fastdiv.h) that kernels can use in
//     place of runtime integer division by tensor dimensions.
struct htp_ops_context {
    struct htp_context * ctx;  // HTP context (see htp-ctx.h)

    enum htp_op op;  // which operation this context describes
    int32_t     op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];  // op-specific params as int32 words; HTP_MAX_OP_PARAMS is a byte count

    // Input tensors (which slots an op uses is op-specific) and the output tensor.
    struct htp_tensor src0;
    struct htp_tensor src1;
    struct htp_tensor src2;
    struct htp_tensor src3;
    struct htp_tensor src4;
    struct htp_tensor dst;

    // Scratch buffers; note there is no scratch buffer for src4.
    struct htp_spad src0_spad;
    struct htp_spad src1_spad;
    struct htp_spad src2_spad;
    struct htp_spad src3_spad;
    struct htp_spad dst_spad;

    worker_pool_context_t * wpool;      // worker pool
    uint32_t                n_threads;  // num threads

    // Rows of src0 / src1 handled by each thread (presumably ceil(nrows / n_threads) — confirm at the setup site).
    uint32_t src0_nrows_per_thread;
    uint32_t src1_nrows_per_thread;

    // Precomputed fastdiv constants for src0 dimensions.
    struct fastdiv_values src0_div1;  // fastdiv values for ne1
    struct fastdiv_values src0_div2;  // fastdiv values for ne2
    struct fastdiv_values src0_div3;  // fastdiv values for ne3
    struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1

    // Precomputed fastdiv constants for src1 dimensions.
    struct fastdiv_values src1_div1;  // fastdiv values for ne1
    struct fastdiv_values src1_div2;  // fastdiv values for ne2
    struct fastdiv_values src1_div3;  // fastdiv values for ne3
    struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1

    // Precomputed fastdiv constants for src3 dimensions (src2 and src4 have none).
    struct fastdiv_values src3_div1;  // fastdiv values for ne1
    struct fastdiv_values src3_div2;  // fastdiv values for ne2
    struct fastdiv_values src3_div3;  // fastdiv values for ne3
    struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1

    // fastdiv constants for broadcast ratios along dims 2 and 3. The names suggest
    // K (rk*) and V (rv*) broadcast factors, e.g. for attention — TODO confirm.
    struct fastdiv_values broadcast_rk2;
    struct fastdiv_values broadcast_rk3;
    struct fastdiv_values broadcast_rv2;
    struct fastdiv_values broadcast_rv3;

    // Used by op_set_rows.
    struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
    struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11

    // Used by op_get_rows.
    struct fastdiv_values get_rows_div_ne10;      // fastdiv values for ne10
    struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11

    uint32_t flags;  // op flags (bit meanings are not defined in this header — confirm)
};
75
// Op entry points. Each one takes the per-op context describing its inputs,
// output, scratch buffers and worker pool.
// NOTE(review): the meaning of the int return value is not visible in this
// header (presumably a status code) — confirm in the implementations.

// Matrix multiplication.
int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);

// Binary, unary/activation and row-reduction ops.
int op_binary(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);

// Softmax, rotary embedding and attention.
int op_softmax(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);

// Row gather/scatter, copy and sort.
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
90
91#endif /* HTP_OPS_H */