1#ifndef MTMD_HELPER_H
 2#define MTMD_HELPER_H
 3
 4#include "ggml.h"
 5#include "llama.h"
 6#include "mtmd.h"
 7
 8#include <stddef.h>
 9#include <stdint.h>
10#include <stdbool.h>
11
12#ifdef __cplusplus
13extern "C" {
14#endif
15
16//
17// libmtmd helper functions
18//
19// Please note that these helpers are not guaranteed to be stable.
20// BREAKING CHANGES are expected.
21//
22
23// Set callback for all future logging events.
24// If this is not called, or NULL is supplied, everything is output on stderr.
25// Note: this also calls mtmd_log_set() internally
//   log_callback: callback receiving log messages (type declared in ggml.h)
//   user_data:    opaque pointer, presumably handed back to log_callback on each call — TODO confirm
26MTMD_API void mtmd_helper_log_set(ggml_log_callback log_callback, void * user_data);
27
28// helper function to construct a mtmd_bitmap from a file
29// it calls mtmd_helper_bitmap_init_from_buf() internally, so the same image/audio
// formats listed for that function are accepted here
//   ctx:   mtmd context
//   fname: path of the file to load
30// returns NULL (nullptr in C++) on failure
// NOTE(review): ownership of the returned bitmap is not stated in this header; presumably
// the caller releases it with mtmd_bitmap_free() from mtmd.h — confirm
31// this function is thread-safe
32MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname);
33
34// helper function to construct a mtmd_bitmap from a buffer containing a file,
// i.e. the encoded file bytes as they would appear on disk (not decoded pixels/samples)
//   ctx: mtmd context
//   buf: encoded file contents; only read by this function (const)
//   len: size of buf in bytes
35// supported formats:
36//     image: formats supported by stb_image: jpg, png, bmp, gif, etc.
37//     audio: formats supported by miniaudio: wav, mp3, flac
38// note: audio files will be auto-detected based on magic bytes
39// returns NULL (nullptr in C++) on failure
40// this function is thread-safe
41MTMD_API mtmd_bitmap * mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, const unsigned char * buf, size_t len);
42
43// helper to count the total number of tokens in a list of chunks, useful to keep track of KV cache usage
// note: this can differ from the number of positions; see mtmd_helper_get_n_pos()
44MTMD_API size_t mtmd_helper_get_n_tokens(const mtmd_input_chunks * chunks);
45
46// helper to count the total number of positions spanned by a list of chunks, useful to keep track of n_past
47// normally, n_pos is equal to n_tokens (see mtmd_helper_get_n_tokens()), but for M-RoPE it is different
48MTMD_API llama_pos mtmd_helper_get_n_pos(const mtmd_input_chunks * chunks);
49
50// helper function that automatically:
51// 1. runs llama_decode() on text chunks
52// 2. runs mtmd_encode() on image chunks, then mtmd_get_output_embd() and then llama_decode()
53// if any of the mtmd_encode() or llama_decode() calls returns non-zero, it stops and forwards the error
54// otherwise, returns 0 on success
55// this function is NOT thread-safe
// parameters:
//   ctx:         mtmd context used for the mtmd_encode() / mtmd_get_output_embd() steps
//   lctx:        llama context on which llama_decode() is run
//   chunks:      the input chunks to evaluate, in order
//   n_past:      presumably the position at which the first chunk starts — TODO confirm
//   seq_id:      presumably the sequence id the decoded tokens are assigned to — TODO confirm
//   n_batch:     presumably the maximum number of tokens per llama_decode() call — TODO confirm
//   logits_last: presumably requests logits only for the final token — TODO confirm
//   new_n_past:  out parameter; presumably receives n_past advanced by the positions
//                consumed (cf. mtmd_helper_get_n_pos()). NOTE(review): whether it is
//                written on failure is not visible from this header
56MTMD_API int32_t mtmd_helper_eval_chunks(mtmd_context * ctx,
57                                         struct llama_context * lctx,
58                                         const mtmd_input_chunks * chunks,
59                                         llama_pos n_past,
60                                         llama_seq_id seq_id,
61                                         int32_t n_batch,
62                                         bool logits_last,
63                                         llama_pos * new_n_past);
64
65// works like mtmd_helper_eval_chunks(), but only for a single chunk;
// the remaining parameters and the return value have the same meaning as in that function
66// this function is NOT thread-safe
67MTMD_API int32_t mtmd_helper_eval_chunk_single(mtmd_context * ctx,
68                                               struct llama_context * lctx,
69                                               const mtmd_input_chunk * chunk,
70                                               llama_pos n_past,
71                                               llama_seq_id seq_id,
72                                               int32_t n_batch,
73                                               bool logits_last,
74                                               llama_pos * new_n_past);
75
76// helper function to decode an image whose embeddings have already been calculated
77// this helper will handle batching and pre/post decoding setup (e.g. gemma 3 requires non-causal attention)
//   chunk:        the image chunk the embeddings belong to
//   encoded_embd: the precomputed embeddings for chunk, presumably as obtained from
//                 mtmd_get_output_embd() after mtmd_encode() — TODO confirm.
//                 NOTE(review): not const-qualified; whether it is written to is not visible here
//   n_past, seq_id, n_batch, new_n_past: same roles as in mtmd_helper_eval_chunks()
78// returns 0 on success, -1 if chunk is not a valid image chunk, 1 on decode failure
79MTMD_API int32_t mtmd_helper_decode_image_chunk(mtmd_context * ctx,
80                                                struct llama_context * lctx,
81                                                const mtmd_input_chunk * chunk,
82                                                float * encoded_embd,
83                                                llama_pos n_past,
84                                                llama_seq_id seq_id,
85                                                int32_t n_batch,
86                                                llama_pos * new_n_past);
87
88#ifdef __cplusplus
89} // extern "C"
90#endif
91
92//
93// C++ wrappers (none are declared in this header yet)
94//
95
96#endif