author     Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
committer  Mitja Felicijan <mitja.felicijan@gmail.com>  2026-02-12 20:57:17 +0100
commit     b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree       211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/src/llama-batch.h
download   llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/src/llama-batch.h')
-rw-r--r--  llama.cpp/src/llama-batch.h | 173
1 file changed, 173 insertions(+), 0 deletions(-)
diff --git a/llama.cpp/src/llama-batch.h b/llama.cpp/src/llama-batch.h
new file mode 100644
index 0000000..8e6fac0
--- /dev/null
+++ b/llama.cpp/src/llama-batch.h
@@ -0,0 +1,173 @@
+#pragma once
+
+#include "llama.h"
+
+#include "llama-cparams.h"
+
+#include <array>
+#include <vector>
+#include <set>
+#include <bitset>
+#include <memory>
+#include <unordered_map>
+
+// keep this struct lightweight
+struct llama_ubatch {
+ bool equal_seqs() const {
+ return b_equal_seqs != 0;
+ }
+
+ // typical for M-RoPE cases:
+ // 0 - sequential position of the tokens/embeddings in the sequence
+ // 1 - y position in the image
+ // 2 - x position in the image
+ // 3 - other
+ bool is_pos_2d() const {
+ // TODO @ngxson : we may need to check for model arch when more models use >1 positions
+ return n_pos >= 3;
+ }
+
+ uint32_t b_equal_seqs; // note: this is a boolean, but we use a uint32_t for alignment
+ // otherwise address sanitizer complains
+ // TODO: whole_seqs for embeddings?
+
+ uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
+ uint32_t n_seq_tokens; // tokens per sequence set
+ uint32_t n_seqs; // sequence sets in the ubatch
+ uint32_t n_seqs_unq; // unique sequence ids in the ubatch
+ uint32_t n_pos; // number of position inputs for each token/embedding
+
+ // seq_id_unq: unique sequence ids in the ubatch
+ // seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+ // used for extracting sequence pooled embeddings
+
+ // // size | idx | val
+ llama_token * token; // [n_tokens] | i | id, token
+ float * embd; // [n_embd, n_tokens] | i | embd
+ llama_pos * pos; // [n_tokens*n_pos] | i | pos
+ int32_t * n_seq_id; // [n_tokens] | i | -
+ llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
+ llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id
+ int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx
+ int8_t * output; // [n_tokens] | i | -
+
+ struct data_t {
+ std::vector<llama_token> token;
+ std::vector<float> embd;
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id; // these point into the seq_id_data below
+ std::vector<llama_seq_id> seq_id_unq;
+ std::vector<int32_t> seq_idx;
+ std::vector<int8_t> output;
+
+ std::vector<llama_seq_id> seq_id_data;
+ };
+
+ // if set, the llama_ubatch pointers above point into this data; otherwise they point to external, non-owning data
+ std::shared_ptr<data_t> data;
+};
+
+// a helper for sanitizing, fulfilling and splitting a batch
+class llama_batch_allocr {
+public:
+ llama_batch_allocr(uint32_t n_pos_per_embd);
+
+ // sanitize and auto-gen missing data in the input batch
+ // memory is optional; if provided, it is used to check for sequence continuity and to determine the positions
+ bool init(
+ const llama_batch & batch_inp,
+ const llama_vocab & vocab,
+ const llama_memory_i * memory,
+ uint32_t n_embd,
+ uint32_t n_seq_max,
+ bool output_all);
+
+ const llama_batch & get_batch() const;
+
+ uint32_t get_n_tokens() const;
+ uint32_t get_n_outputs() const;
+ uint32_t get_n_used() const;
+
+ // the array of output indices in the order they were encountered during the ubatch splitting
+ std::vector<int32_t> & get_out_ids();
+
+ // min/max positions of each sequence in the current ubatch
+ llama_pos seq_pos_min(llama_seq_id seq_id) const;
+ llama_pos seq_pos_max(llama_seq_id seq_id) const;
+
+ // call once before splitting the batch to reset the internal state
+ void split_reset();
+
+ // simple split, unknown number of sequence sets of unequal lengths
+ llama_ubatch split_simple(uint32_t n_ubatch);
+
+ // make ubatches of equal-length sequence sets
+ // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+ llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
+
+ // sequence-set-wise split - each ubatch contains a single sequence-set
+ llama_ubatch split_seq(uint32_t n_ubatch);
+
+ // a helper method for creating a well-defined ubatch of tokens
+ // TODO: support embeddings if needed in the future
+ llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
+private:
+ void clear();
+
+ // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+ // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+ llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+ // for debugging, start with LLAMA_BATCH_DEBUG=2
+ void ubatch_print(const llama_ubatch & ubatch, int debug);
+
+ llama_batch batch;
+
+ // only for debugging purposes
+ const llama_vocab * vocab;
+
+ // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+ // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+ const uint32_t n_pos_per_embd;
+
+ uint32_t n_embd;
+ uint32_t n_seq_max;
+ uint32_t n_outputs;
+
+ std::array<llama_seq_id, 1> seq_id_0 = {{ 0 }}; // default sequence id
+
+ std::vector<llama_pos> pos;
+ std::vector<int32_t> n_seq_id;
+ std::vector<llama_seq_id *> seq_id;
+ std::vector<llama_seq_id> seq_id_unq;
+ std::vector<int32_t> seq_idx;
+ std::vector<int8_t> output;
+
+ using pos_set_t = std::set<llama_pos>;
+ using seq_cpl_t = std::vector<bool>;
+
+ // helper flag to quickly determine if there are any coupled sequences in the batch
+ bool has_cpl = false;
+
+ std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+ std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: whether sequence s0 is coupled to sequence s1
+
+ using idx_vec_t = std::vector<int32_t>;
+ using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+ std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+ std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+ // batch indices of the output
+ std::vector<int32_t> out_ids;
+
+ uint32_t n_used;
+
+ // used[i] indicates if token i has already been used in a previous ubatch
+ std::vector<bool> used;
+
+ int debug;
+};
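
For reference, here is a minimal sketch of how the allocator declared above can be driven, based only on the declarations in this header. The run_batch wrapper, the n_ubatch parameter, and the choice to pass no memory (which the init() comment marks optional) are illustrative assumptions, not upstream code:

#include "llama-batch.h"

// hypothetical driver: sanitize a batch, then consume it as simple ubatches
static void run_batch(llama_batch_allocr & balloc,
                      const llama_batch  & batch_inp,
                      const llama_vocab  & vocab,      // from the loaded model
                      uint32_t n_embd,
                      uint32_t n_seq_max,
                      uint32_t n_ubatch) {
    // sanitize and auto-gen missing data; passing no memory skips the
    // sequence-continuity checks
    if (!balloc.init(batch_inp, vocab, /*memory =*/ nullptr,
                     n_embd, n_seq_max, /*output_all =*/ false)) {
        return; // the input batch failed validation
    }

    balloc.split_reset(); // reset internal split state before the first ubatch

    while (true) {
        // simple split: up to n_ubatch tokens per ubatch, sequence sets of
        // unequal lengths
        llama_ubatch ubatch = balloc.split_simple(n_ubatch);
        if (ubatch.n_tokens == 0) {
            break; // the entire batch has been consumed
        }

        for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
            const llama_token id = ubatch.token[i]; // token id of token i
            const llama_pos   p  = ubatch.pos[i];   // first of its n_pos positions
            // seq_id[i][0 .. n_seq_id[i]) lists the sequences token i belongs to
            (void) id; (void) p; // hand the ubatch off to the compute graph here
        }
    }
}

Judging from the declarations, split_equal() and split_seq() slot into the same loop shape: call split_reset() once, then call the chosen split until the returned ubatch has n_tokens == 0; only the splitting policy differs.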