1#ifndef HTP_OPS_H
2#define HTP_OPS_H
3
4#include "htp-ctx.h"
5#include "htp-msg.h"
6#include "worker-pool.h"
7
8#include <assert.h>
9#include <stdint.h>
10
11#include <hex-fastdiv.h>
12
13// ggml-common.h must be included prior to this header
14
// Descriptor for one buffer: a byte pointer plus three size_t quantities.
//
// NOTE(review): "spad" presumably means scratchpad, and the fields read like
// base pointer / stride / total size / per-thread share. Units and ownership
// are not established anywhere in this header -- confirm against the code
// that populates and consumes these descriptors.
// NOTE(review): size_t is not guaranteed by <stdint.h>; presumably one of the
// headers included earlier provides it -- confirm.
struct htp_spad {
    uint8_t * data;             // buffer, addressed as bytes
    size_t    stride;           // presumably the step between consecutive rows -- TODO confirm
    size_t    size;             // presumably the total size of the buffer -- TODO confirm
    size_t    size_per_thread;  // presumably the portion reserved for each thread -- TODO confirm
};
21
22struct htp_ops_context {
23 struct htp_context * ctx;
24
25 enum htp_op op;
26 int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)];
27
28 struct htp_tensor src0;
29 struct htp_tensor src1;
30 struct htp_tensor src2;
31 struct htp_tensor src3;
32 struct htp_tensor src4;
33 struct htp_tensor dst;
34
35 struct htp_spad src0_spad;
36 struct htp_spad src1_spad;
37 struct htp_spad src2_spad;
38 struct htp_spad src3_spad;
39 struct htp_spad dst_spad;
40
41 worker_pool_context_t * wpool; // worker pool
42 uint32_t n_threads; // num threads
43
44 uint32_t src0_nrows_per_thread;
45 uint32_t src1_nrows_per_thread;
46
47 struct fastdiv_values src0_div1; // fastdiv values for ne1
48 struct fastdiv_values src0_div2; // fastdiv values for ne2
49 struct fastdiv_values src0_div3; // fastdiv values for ne3
50 struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1
51
52 struct fastdiv_values src1_div1; // fastdiv values for ne1
53 struct fastdiv_values src1_div2; // fastdiv values for ne2
54 struct fastdiv_values src1_div3; // fastdiv values for ne3
55 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
56
57 struct fastdiv_values src3_div1; // fastdiv values for ne1
58 struct fastdiv_values src3_div2; // fastdiv values for ne2
59 struct fastdiv_values src3_div3; // fastdiv values for ne3
60 struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
61
62 struct fastdiv_values broadcast_rk2;
63 struct fastdiv_values broadcast_rk3;
64 struct fastdiv_values broadcast_rv2;
65 struct fastdiv_values broadcast_rv3;
66
67 struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
68 struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
69
70 struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
71 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
72
73 uint32_t flags;
74};
75
// Operation entry points.
//
// All of them share one signature: they take a pointer to the
// htp_ops_context for the operation and return an int. The meaning of the
// return value is not defined in this header (presumably a status code --
// confirm against the implementations).
int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
90
91#endif /* HTP_OPS_H */