1#include "arg.h"
2#include "common.h"
3#include "log.h"
4#include "llama.h"
5
6#include <ctime>
7#include <algorithm>
8
9#if defined(_MSC_VER)
10#pragma warning(disable: 4244 4267) // possible loss of data
11#endif
12
// Split `s` into the pieces found between occurrences of `separator`.
//
// The text after the last separator (or all of `s` when the separator never occurs) is always
// appended, so the result holds at least one element; adjacent separators yield empty strings.
// An empty `separator` cannot delimit anything, so in that case `s` is returned as a single piece.
static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
    std::vector<std::string> lines;

    // std::string::find("") matches at the search position itself, so the scan below would
    // never advance and would keep appending empty strings forever
    if (separator.empty()) {
        lines.push_back(s);
        return lines;
    }

    size_t start = 0;
    size_t end   = s.find(separator);

    while (end != std::string::npos) {
        lines.push_back(s.substr(start, end - start));
        start = end + separator.length();
        end   = s.find(separator, start);
    }

    lines.push_back(s.substr(start)); // add the last part

    return lines;
}
28
29static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
30 size_t n_tokens = tokens.size();
31 for (size_t i = 0; i < n_tokens; i++) {
32 common_batch_add(batch, tokens[i], i, { seq_id }, true);
33 }
34}
35
36static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) {
37 const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
38
39 // clear previous kv_cache values (irrelevant for embeddings)
40 llama_memory_clear(llama_get_memory(ctx), true);
41
42 // run model
43 LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
44 if (llama_decode(ctx, batch) < 0) {
45 LOG_ERR("%s : failed to process\n", __func__);
46 }
47
48 for (int i = 0; i < batch.n_tokens; i++) {
49 if (!batch.logits[i]) {
50 continue;
51 }
52
53 const float * embd = nullptr;
54 int embd_pos = 0;
55
56 if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
57 // try to get token embeddings
58 embd = llama_get_embeddings_ith(ctx, i);
59 embd_pos = i;
60 GGML_ASSERT(embd != NULL && "failed to get token embeddings");
61 } else {
62 // try to get sequence embeddings - supported only when pooling_type is not NONE
63 embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
64 embd_pos = batch.seq_id[i][0];
65 GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
66 }
67
68 float * out = output + embd_pos * n_embd_out;
69 common_embd_normalize(embd, out, n_embd_out, embd_norm);
70 }
71}
72
73// plain, pipe-friendly output: one embedding per line
74static void print_raw_embeddings(const float * emb,
75 int n_embd_count,
76 int n_embd,
77 const llama_model * model,
78 enum llama_pooling_type pooling_type,
79 int embd_normalize) {
80 const uint32_t n_cls_out = llama_model_n_cls_out(model);
81 const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK);
82 const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd;
83
84 for (int j = 0; j < n_embd_count; ++j) {
85 for (int i = 0; i < cols; ++i) {
86 if (embd_normalize == 0) {
87 LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
88 } else {
89 LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : ""));
90 }
91 }
92 LOG("\n");
93 }
94}
95
96int main(int argc, char ** argv) {
97 common_params params;
98
99 if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) {
100 return 1;
101 }
102
103 common_init();
104
105 params.embedding = true;
106
107 // get max number of sequences per batch
108 const int n_seq_max = llama_max_parallel_sequences();
109
110 // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
111 // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
112 // in order to support any number of prompts
113 if (params.n_parallel == 1) {
114 LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
115 params.kv_unified = true;
116 params.n_parallel = n_seq_max;
117 }
118
119 // utilize the full context
120 if (params.n_batch < params.n_ctx) {
121 LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
122 params.n_batch = params.n_ctx;
123 }
124
125 // for non-causal models, batch size must be equal to ubatch size
126 if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) {
127 params.n_ubatch = params.n_batch;
128 }
129
130 llama_backend_init();
131 llama_numa_init(params.numa);
132
133 // load the model
134 auto llama_init = common_init_from_params(params);
135
136 auto * model = llama_init->model();
137 auto * ctx = llama_init->context();
138
139 if (model == NULL) {
140 LOG_ERR("%s: unable to load model\n", __func__);
141 return 1;
142 }
143
144 const llama_vocab * vocab = llama_model_get_vocab(model);
145
146 const int n_ctx_train = llama_model_n_ctx_train(model);
147 const int n_ctx = llama_n_ctx(ctx);
148
149 const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
150
151 if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
152 LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__);
153 return 1;
154 }
155
156 if (n_ctx > n_ctx_train) {
157 LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n",
158 __func__, n_ctx_train, n_ctx);
159 }
160
161 // print system information
162 {
163 LOG_INF("\n");
164 LOG_INF("%s\n", common_params_get_system_info(params).c_str());
165 }
166
167 // split the prompt into lines
168 std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
169
170 // max batch size
171 const uint64_t n_batch = params.n_batch;
172
173 // get added sep and eos token, if any
174 const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : "";
175 const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : "";
176 const char * rerank_prompt = llama_model_chat_template(model, "rerank");
177
178 // tokenize the prompts and trim
179 std::vector<std::vector<int32_t>> inputs;
180 for (const auto & prompt : prompts) {
181 std::vector<llama_token> inp;
182
183 // split classification pairs and insert expected separator tokens
184 if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) {
185 std::vector<std::string> pairs = split_lines(prompt, params.cls_sep);
186 if (rerank_prompt != nullptr) {
187 const std::string query = pairs[0];
188 const std::string doc = pairs[1];
189 std::string final_prompt = rerank_prompt;
190 string_replace_all(final_prompt, "{query}" , query);
191 string_replace_all(final_prompt, "{document}", doc );
192 inp = common_tokenize(vocab, final_prompt, true, true);
193 } else {
194 std::string final_prompt;
195 for (size_t i = 0; i < pairs.size(); i++) {
196 final_prompt += pairs[i];
197 if (i != pairs.size() - 1) {
198 if (!added_eos_token.empty()) {
199 final_prompt += added_eos_token;
200 }
201 if (!added_sep_token.empty()) {
202 final_prompt += added_sep_token;
203 }
204 }
205 }
206 inp = common_tokenize(ctx, final_prompt, true, true);
207 }
208 } else {
209 inp = common_tokenize(ctx, prompt, true, true);
210 }
211 if (inp.size() > n_batch) {
212 LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n",
213 __func__, (long long int) inp.size(), (long long int) n_batch);
214 return 1;
215 }
216 inputs.push_back(inp);
217 }
218
219 // check if the last token is SEP/EOS
220 // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
221 for (auto & inp : inputs) {
222 if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) {
223 LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__);
224 LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
225 }
226 }
227
228 // tokenization stats
229 if (params.verbose_prompt) {
230 for (int i = 0; i < (int) inputs.size(); i++) {
231 LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
232 LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
233 for (int j = 0; j < (int) inputs[i].size(); j++) {
234 LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str());
235 }
236 LOG("\n\n");
237 }
238 }
239
240 // initialize batch
241 const int n_prompts = prompts.size();
242 struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
243
244 // count number of embeddings
245 int n_embd_count = 0;
246 if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
247 for (int k = 0; k < n_prompts; k++) {
248 n_embd_count += inputs[k].size();
249 }
250 } else {
251 n_embd_count = n_prompts;
252 }
253
254 // allocate output
255 const int n_embd_out = llama_model_n_embd_out(model);
256 std::vector<float> embeddings(n_embd_count * n_embd_out, 0);
257 float * emb = embeddings.data();
258
259 // break into batches
260 int e = 0; // number of embeddings already stored
261 int s = 0; // number of prompts in current batch
262 for (int k = 0; k < n_prompts; k++) {
263 // clamp to n_batch tokens
264 auto & inp = inputs[k];
265
266 const uint64_t n_toks = inp.size();
267
268 // encode if at capacity
269 if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
270 float * out = emb + e * n_embd_out;
271 batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
272 e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
273 s = 0;
274 common_batch_clear(batch);
275 }
276
277 // add to batch
278 batch_add_seq(batch, inp, s);
279 s += 1;
280 }
281
282 // final batch
283 float * out = emb + e * n_embd_out;
284 batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize);
285
286 if (params.embd_out.empty()) {
287 LOG("\n");
288
289 if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
290 for (int j = 0; j < n_embd_count; j++) {
291 LOG("embedding %d: ", j);
292 for (int i = 0; i < std::min(3, n_embd_out); i++) {
293 if (params.embd_normalize == 0) {
294 LOG("%6.0f ", emb[j * n_embd_out + i]);
295 } else {
296 LOG("%9.6f ", emb[j * n_embd_out + i]);
297 }
298 }
299 LOG(" ... ");
300 for (int i = n_embd_out - 3; i < n_embd_out; i++) {
301 if (params.embd_normalize == 0) {
302 LOG("%6.0f ", emb[j * n_embd_out + i]);
303 } else {
304 LOG("%9.6f ", emb[j * n_embd_out + i]);
305 }
306 }
307 LOG("\n");
308 }
309 } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) {
310 const uint32_t n_cls_out = llama_model_n_cls_out(model);
311 std::vector<std::string> cls_out_labels;
312
313 for (uint32_t i = 0; i < n_cls_out; i++) {
314 const char * label = llama_model_cls_label(model, i);
315 const std::string label_i(label == nullptr ? "" : label);
316 cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i);
317 }
318
319 for (int j = 0; j < n_embd_count; j++) {
320 for (uint32_t i = 0; i < n_cls_out; i++) {
321 // NOTE: if you change this log - update the tests in ci/run.sh
322 if (n_cls_out == 1) {
323 LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]);
324 } else {
325 LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str());
326 }
327 }
328 }
329 } else {
330 // print the first part of the embeddings or for a single prompt, the full embedding
331 for (int j = 0; j < n_prompts; j++) {
332 LOG("embedding %d: ", j);
333 for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) {
334 if (params.embd_normalize == 0) {
335 LOG("%6.0f ", emb[j * n_embd_out + i]);
336 } else {
337 LOG("%9.6f ", emb[j * n_embd_out + i]);
338 }
339 }
340 LOG("\n");
341 }
342
343 // print cosine similarity matrix
344 if (n_prompts > 1) {
345 LOG("\n");
346 LOG("cosine similarity matrix:\n\n");
347 for (int i = 0; i < n_prompts; i++) {
348 LOG("%6.6s ", prompts[i].c_str());
349 }
350 LOG("\n");
351 for (int i = 0; i < n_prompts; i++) {
352 for (int j = 0; j < n_prompts; j++) {
353 float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
354 LOG("%6.2f ", sim);
355 }
356 LOG("%1.10s", prompts[i].c_str());
357 LOG("\n");
358 }
359 }
360 }
361 }
362
363 if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
364 const bool notArray = params.embd_out != "array";
365
366 LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "[");
367 for (int j = 0;;) { // at least one iteration (one prompt)
368 if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
369 LOG("[");
370 for (int i = 0;;) { // at least one iteration (n_embd > 0)
371 LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]);
372 i++;
373 if (i < n_embd_out) LOG(","); else break;
374 }
375 LOG(notArray ? "]\n }" : "]");
376 j++;
377 if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break;
378 }
379 LOG(notArray ? "\n ]" : "]\n");
380
381 if (params.embd_out == "json+" && n_prompts > 1) {
382 LOG(",\n \"cosineSimilarity\": [\n");
383 for (int i = 0;;) { // at least two iteration (n_embd_count > 1)
384 LOG(" [");
385 for (int j = 0;;) { // at least two iteration (n_embd_count > 1)
386 float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out);
387 LOG("%6.2f", sim);
388 j++;
389 if (j < n_embd_count) LOG(", "); else break;
390 }
391 LOG(" ]");
392 i++;
393 if (i < n_embd_count) LOG(",\n"); else break;
394 }
395 LOG("\n ]");
396 }
397
398 if (notArray) LOG("\n}\n");
399 } else if (params.embd_out == "raw") {
400 print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize);
401 }
402
403 LOG("\n");
404 llama_perf_context_print(ctx);
405
406 // clean up
407 llama_batch_free(batch);
408 llama_backend_free();
409
410 return 0;
411}