1#include "llama-batch.h"
  2
  3#include "llama-impl.h"
  4#include "llama-vocab.h"
  5#include "llama-memory.h"
  6
  7#include <cassert>
  8#include <cstring>
  9#include <algorithm>
 10#include <sstream>
 11
 12llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
 13    const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
 14    debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
 15
 16    seq_pos.resize(LLAMA_MAX_SEQ);
 17    seq_cpl.resize(LLAMA_MAX_SEQ);
 18    for (auto & cur : seq_cpl) {
 19        cur.resize(LLAMA_MAX_SEQ);
 20    }
 21
 22    seq_idx.resize(LLAMA_MAX_SEQ, -1);
 23}
 24
 25bool llama_batch_allocr::init(
 26        const llama_batch & batch_inp,
 27        const llama_vocab & vocab,
 28        const llama_memory_i * memory,
 29        uint32_t n_embd,
 30        uint32_t n_seq_max,
 31        bool output_all) {
 32    clear();
 33
 34    batch = batch_inp;
 35
 36    this->vocab = &vocab;
 37
 38    GGML_ASSERT(batch.n_tokens > 0);
 39
 40    //
 41    // validate input batch
 42    //
 43
 44    if (n_seq_max > LLAMA_MAX_SEQ) {
 45        LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
 46        return false;
 47    }
 48
 49    if (batch.token) {
 50        for (int32_t i = 0; i < batch.n_tokens; ++i) {
 51            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
 52                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
 53                return false;
 54            }
 55        }
 56    }
 57
 58    if (batch.seq_id) {
 59        for (int32_t i = 0; i < batch.n_tokens; ++i) {
 60            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
 61                if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
 62                    LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
 63                    return false;
 64                }
 65            }
 66        }
 67    }
 68
 69    //
 70    // auto-generate missing fields
 71    //
 72
 73    if (!batch.n_seq_id) {
 74        n_seq_id.resize(batch.n_tokens);
 75        for (int32_t i = 0; i < batch.n_tokens; i++) {
 76            n_seq_id[i] = seq_id_0.size();
 77        }
 78        batch.n_seq_id = n_seq_id.data();
 79    }
 80
 81    if (!batch.seq_id) {
 82        seq_id.resize(batch.n_tokens + 1);
 83        seq_id[batch.n_tokens] = NULL;
 84        for (int32_t i = 0; i < batch.n_tokens; i++) {
 85            seq_id[i] = seq_id_0.data();
 86        }
 87        batch.seq_id = seq_id.data();
 88    }
 89
 90    if (!batch.pos) {
 91        pos.resize(batch.n_tokens);
 92
 93        // initialize the starting position for each sequence based on the positions in the memory
 94        llama_pos p0[LLAMA_MAX_SEQ];
 95        for (uint32_t s = 0; s < n_seq_max; ++s) {
 96            if (!memory) {
 97                // if no memory -> start from 0
 98                p0[s] = 0;
 99            } else {
100                p0[s] = memory->seq_pos_max(s) + 1;
101            }
102        }
103
104        for (int32_t i = 0; i < batch.n_tokens; i++) {
105            const llama_seq_id seq_id = batch.seq_id[i][0];
106
107            pos[i] = p0[seq_id];
108
109            // update the starting position for all sequences that are assigned to the this token
110            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
111                const llama_seq_id seq_id = batch.seq_id[i][s];
112
113                p0[seq_id] = pos[i] + 1;
114            }
115        }
116
117        batch.pos = pos.data();
118    }
119
120    if (!batch.logits) {
121        if (output_all) {
122            // return the output for all tokens
123            output.resize(batch.n_tokens, true);
124        } else {
125            // return the output only for the last token
126            output.resize(batch.n_tokens, false);
127            output[output.size() - 1] = true;
128        }
129
130        batch.logits = output.data();
131    } else if (output_all) {
132        bool warn = false;
133
134        for (int32_t i = 0; i < batch.n_tokens; ++i) {
135            if (batch.logits[i] == 0) {
136                warn = true;
137            }
138        }
139
140        if (warn) {
141            LLAMA_LOG_WARN("%s: embeddings required but some input tokens were not marked as outputs -> overriding\n", __func__);
142
143            output.resize(batch.n_tokens, true);
144            batch.logits = output.data();
145        }
146    }
147
148    //
149    // compute stats
150    //
151
152    this->n_embd    = n_embd;
153    this->n_seq_max = n_seq_max;
154
155    // count the outputs in this batch
156    for (int32_t i = 0; i < batch.n_tokens; ++i) {
157        n_outputs += batch.logits[i] != 0;
158    }
159
160    has_cpl = false;
161
162    // determine coupled sequences
163    // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
164    for (int32_t i = 0; i < batch.n_tokens; ++i) {
165        const llama_seq_id s0 = batch.seq_id[i][0];
166
167        for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
168            const llama_seq_id s1 = batch.seq_id[i][s];
169
170            seq_pos[s1].insert(batch.pos[i]);
171
172            if (s > 0) {
173                // mark that sequence s1 is coupled to s0
174                seq_cpl[s1][s0] = true;
175
176                // note: tracking the other way around is not necessary for now
177                //seq_cpl[s0][s1] = true;
178
179                has_cpl = true;
180            }
181        }
182    }
183
184    // precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
185    {
186        seq_set_t seq_set_unq;
187
188        for (int32_t i = 0; i < batch.n_tokens; ++i) {
189            seq_set_t cur;
190            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
191                const llama_seq_id seq_id = batch.seq_id[i][s];
192
193                cur        .set(seq_id);
194                seq_set_unq.set(seq_id);
195            }
196
197            seq_set.push_back(cur);
198            seq_set_map[cur].push_back(i);
199        }
200
201        for (uint32_t s = 0; s < n_seq_max; ++s) {
202            if (seq_set_unq.test(s)) {
203                seq_idx[s] = seq_id_unq.size();
204                seq_id_unq.push_back(s);
205            }
206        }
207    }
208
209    if (debug > 0) {
210        LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
211
212        llama_ubatch ubatch {
213            /*.b_equal_seqs =*/ false,
214            /*.n_tokens     =*/ (uint32_t) batch.n_tokens,
215            /*.n_seq_tokens =*/ (uint32_t) 1,
216            /*.n_seqs       =*/ (uint32_t) batch.n_tokens,
217            /*.n_seqs_unq   =*/ (uint32_t) this->seq_id_unq.size(),
218            /*.n_pos        =*/ n_pos_per_embd,
219            /*.token        =*/ batch.token,
220            /*.embd         =*/ batch.embd,
221            /*.pos          =*/ batch.pos,
222            /*.n_seq_id     =*/ batch.n_seq_id,
223            /*.seq_id       =*/ batch.seq_id,
224            /*.seq_id_unq   =*/ this->seq_id_unq.data(),
225            /*.seq_idx      =*/ this->seq_idx.data(),
226            /*.output       =*/ batch.logits,
227            /*.data         =*/ {},
228        };
229
230        ubatch_print(ubatch, debug);
231
232        LLAMA_LOG_DEBUG("%s:   seq       = [\n", __func__);
233        for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
234            if (seq_pos[s0].empty()) {
235                continue;
236            }
237
238            std::stringstream ss;
239            for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
240                if (seq_cpl[s0][s1]) {
241                    ss << s1 << " ";
242                }
243            }
244
245            LLAMA_LOG_DEBUG("%s:  %4d: pos = [%4d, %4d], cpl = %s\n",
246                    __func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
247        }
248        LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
249    }
250
251    //
252    // consistency checks
253    //
254
255    if (n_pos_per_embd > 1) {
256        // M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
257        for (uint32_t s = 0; s < n_seq_max; ++s) {
258            if (seq_pos[s].empty()) {
259                continue;
260            }
261
262            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
263
264            if (batch.token) {
265                if (p0 >= 0 && p0 >= seq_pos_min(s)) {
266                    LLAMA_LOG_ERROR(
267                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
268                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
269                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
270                            " for M-RoPE, it is required that the position satisfies: X < Y\n",
271                            __func__, s, s, p0, s, seq_pos_min(s));
272
273                    return false;
274                }
275            } else {
276                // embedding inputs can have overlapping positions
277                if (p0 >= 0 && p0 > seq_pos_min(s)) {
278                    LLAMA_LOG_ERROR(
279                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
280                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
281                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
282                            " for M-RoPE, it is required that the position satisfies: X <= Y\n",
283                            __func__, s, s, p0, s, seq_pos_min(s));
284
285                    return false;
286                }
287            }
288        }
289    } else {
290        for (uint32_t s = 0; s < n_seq_max; ++s) {
291            if (seq_pos[s].empty()) {
292                continue;
293            }
294
295            const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
296
297            if (p0 >= 0) {
298                bool ok = true;
299
300                if (seq_pos_min(s) != p0 + 1) {
301                    ok = false;
302                }
303
304                if (!ok) {
305                    LLAMA_LOG_ERROR(
306                            "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
307                            " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
308                            " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
309                            " it is required that the sequence positions remain consecutive: Y = X + 1\n",
310                            __func__, s, s, p0, s, seq_pos_min(s));
311
312                    return false;
313                }
314            }
315
316            if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
317                LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
318                return false;
319            }
320        }
321    }
322
323    if (memory) {
324        for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
325            for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
326                if (seq_cpl[s0][s1]) {
327                    if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
328                        memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
329                        LLAMA_LOG_ERROR("%s: sequence %d is coupled to %d in the input batch, but have divereged\n", __func__, s0, s1);
330                        return false;
331                    }
332                }
333            }
334        }
335    }
336
337    // disallow partial sequence sub-sets:
338    //
339    // invalid:          x
340    //            i: 0 1 2 ...
341    // ---------------------------------------
342    // seq_id[i][0]: 0 0 1
343    // seq_id[i][1]: 1 1 2
344    // seq_id[i][2]: 2
345    //
346    // disallow decreasing sequence positions:
347    //
348    // invalid:                  x
349    //            i: 0 1 2 3 4 5 6 ...
350    // ---------------------------------------
351    //       pos[i]: 4 5 0 1 6 2 3
352    // seq_id[i][0]: 0 0 1 1 0 1 0
353    //
354    {
355        seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
356        for (uint32_t s = 0; s < n_seq_max; ++s) {
357            cur_seq_set[s].set();
358        }
359
360        llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
361        for (uint32_t s = 0; s < n_seq_max; ++s) {
362            cur_seq_pos[s] = -1;
363        }
364
365        for (int32_t i = 0; i < batch.n_tokens; ++i) {
366            const llama_pos pos = batch.pos[i];
367
368            for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
369                const llama_seq_id seq_id = batch.seq_id[i][s];
370
371                cur_seq_set[seq_id] &= seq_set[i];
372
373                if (cur_seq_set[seq_id].none()) {
374                    LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
375                    return false;
376                }
377
378                if (pos < cur_seq_pos[seq_id]) {
379                    LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
380                    return false;
381                }
382            }
383        }
384    }
385
386    split_reset();
387
388    return true;
389}
390
391llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
392    const uint32_t n_tokens = n_seq_tokens*n_seqs;
393
394    clear();
395    split_reset();
396
397    auto udata = std::make_shared<llama_ubatch::data_t>();
398
399    udata->token     .resize(n_tokens);
400    udata->embd      .clear();
401    udata->pos       .resize(n_tokens);
402    udata->n_seq_id  .resize(n_tokens);
403    udata->seq_id    .resize(n_tokens);
404    udata->seq_id_unq.resize(0);
405    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
406    udata->output    .resize(n_tokens);
407
408    for (uint32_t s = 0; s < n_seqs; ++s) {
409        udata->seq_idx[s] = s;
410        udata->seq_id_unq.push_back(s);
411    }
412
413    llama_ubatch res {
414        /*.b_equal_seqs =*/ true,
415        /*.n_tokens     =*/ n_tokens,
416        /*.n_seq_tokens =*/ n_seq_tokens,
417        /*.n_seqs       =*/ n_seqs,
418        /*.n_seqs_unq   =*/ n_seqs,
419        /*.n_pos        =*/ n_pos_per_embd,
420
421        /*.token        =*/ udata->token.data(),
422        /*.embd         =*/ nullptr,
423        /*.pos          =*/ udata->pos.data(),
424        /*.n_seq_id     =*/ udata->n_seq_id.data(),
425        /*.seq_id       =*/ udata->seq_id.data(),
426        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
427        /*.seq_idx      =*/ udata->seq_idx.data(),
428        /*.output       =*/ udata->output.data(),
429        /*.data         =*/ std::move(udata),
430    };
431
432    return res;
433}
434
435const llama_batch & llama_batch_allocr::get_batch() const {
436    return batch;
437}
438
439uint32_t llama_batch_allocr::get_n_tokens() const {
440    return batch.n_tokens;
441}
442
443uint32_t llama_batch_allocr::get_n_outputs() const {
444    return n_outputs;
445}
446
447uint32_t llama_batch_allocr::get_n_used() const {
448    return n_used;
449}
450
451std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
452    return out_ids;
453}
454
455llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
456    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
457}
458
459llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
460    return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
461}
462
463void llama_batch_allocr::split_reset() {
464    out_ids.clear();
465
466    n_used = 0;
467
468    used.clear();
469    used.resize(get_n_tokens(), false);
470}
471
472llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
473    // find the first unused token
474    uint32_t cur_idx = 0;
475    while (cur_idx < used.size() && used[cur_idx]) {
476        ++cur_idx;
477    }
478
479    // we are done
480    if (cur_idx >= used.size()) {
481        return {};
482    }
483
484    std::vector<int32_t> idxs;
485
486    while (true) {
487        idxs.push_back(cur_idx);
488
489        used[cur_idx] = true;
490        ++n_used;
491
492        ++cur_idx;
493
494        if (cur_idx >= used.size()) {
495            break;
496        }
497
498        if (idxs.size() >= n_ubatch) {
499            break;
500        }
501    }
502
503    return ubatch_add(idxs, idxs.size(), false);
504}
505
// build a ubatch with an equal number of tokens taken from several non-overlapping
// sequence sets; with sequential = true, only adjacent (increasing) sequence ids are grouped
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
    // coupled sequences must always be processed together, which is incompatible
    // with a strictly sequential (by sequence id) grouping
    if (sequential && has_cpl) {
        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch (you may need to use the -kvu flag)\n", __func__);

        return {};
    }

    std::vector<seq_set_t> cur_seq_set;

    llama_seq_id last_seq_id = -1;

    // determine the non-overlapping sequence sets participating in this ubatch
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (used[i]) {
            continue;
        }

        bool add = true;

        for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
            // no overlap with existing sequence sets:
            if (!(cur_seq_set[s] & seq_set[i]).none()) {
                add = false;
                break;
            }
        }

        // accept only increasing sequence ids
        if (sequential) {
            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
        }

        if (add) {
            cur_seq_set.push_back(seq_set[i]);

            last_seq_id = batch.seq_id[i][0];

            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
        }
    }

    const uint32_t n_seqs = cur_seq_set.size();

    // we are done
    if (n_seqs == 0) {
        return {};
    }

    // the current batch index of each sequence set
    std::vector<int32_t> cur_idx(n_seqs, 0);

    // skip over the tokens of each sequence set that were consumed by previous ubatches
    for (uint32_t s = 0; s < n_seqs; ++s) {
        while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
            ++cur_idx[s];
        }
    }

    // the list of batch indices for each sequence set
    // at the end we will concat these to get the final ubatch
    std::vector<idx_vec_t> idxs_per_seq(n_seqs);

    while (true) {
        // we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
        //   if we haven't reached n_ubatch
        bool can_expand = true;

        for (uint32_t s = 0; s < n_seqs; ++s) {
            if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
                can_expand = false;
                break;
            }
        }

        if (!can_expand) {
            break;
        }

        // take exactly one token from every sequence set, keeping the ubatch equal-sized
        for (uint32_t s = 0; s < n_seqs; ++s) {
            const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];

            idxs_per_seq[s].push_back(idx);

            used[idx] = true;
            ++n_used;

            ++cur_idx[s];
        }

        // stop before the next round would exceed the n_ubatch token budget
        if  ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
            break;
        }
    }

    // concat the per-sequence-set lists
    std::vector<int32_t> idxs;

    for (uint32_t s = 0; s < n_seqs; ++s) {
        idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
    }

    return ubatch_add(idxs, n_seqs, true);
}
610
// build a ubatch of up to n_ubatch tokens that all belong to a single sequence set
// (or to subsets of it), e.g. for recurrent models that process one sequence at a time
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
    // find the first unused token
    uint32_t cur_idx = 0;
    while (cur_idx < used.size() && used[cur_idx]) {
        ++cur_idx;
    }

    // we are done
    if (cur_idx >= used.size()) {
        return {};
    }

    // this is the starting sequence set
    // we allow adding tokens only if their sequence set is a subset of the current sequence set
    auto cur_seq_set = seq_set[cur_idx];

    std::vector<int32_t> idxs;

    while (true) {
        idxs.push_back(cur_idx);

        used[cur_idx] = true;
        ++n_used;

        // respect the requested token budget
        if (idxs.size() >= n_ubatch) {
            break;
        }

        // advance to the next unused token whose sequence set is a subset of the current one
        do {
            ++cur_idx;
        } while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));

        if (cur_idx == get_n_tokens()) {
            break;
        }

        // narrow the current sequence set to that of the accepted token
        cur_seq_set = seq_set[cur_idx];
    }

    // all tokens belong to (subsets of) one sequence set -> n_seqs = 1
    return ubatch_add(idxs, 1, true);
}
652
653void llama_batch_allocr::clear() {
654    n_outputs = 0;
655
656    batch = {};
657
658    pos       .clear();
659    n_seq_id  .clear();
660    seq_id    .clear();
661    seq_id_unq.clear();
662    output    .clear();
663
664    for (auto & cur : seq_pos) {
665        cur.clear();
666    }
667
668    for (auto & cur : seq_cpl) {
669        std::fill(cur.begin(), cur.end(), false);
670    }
671
672    seq_set.clear();
673
674    seq_set_map.clear();
675
676    std::fill(seq_idx.begin(), seq_idx.end(), -1);
677}
678
// materialize a ubatch from the given batch indices: copy the selected tokens/embeddings
// into owned buffers and record the output ordering in out_ids
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
    const uint32_t n_tokens = idxs.size();

    // the tokens must be evenly distributable across the sequence sets
    assert(n_tokens%n_seqs == 0);

    // the resulting ubatch owns its buffers through this shared holder
    auto udata = std::make_shared<llama_ubatch::data_t>();

    // embeddings are copied only when present in the input batch
    const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
    const int64_t n_pos_all  =              (int64_t) n_tokens*n_pos_per_embd;

    udata->token     .resize(n_tokens);
    udata->embd      .resize(n_embd_all);
    udata->pos       .resize(n_pos_all);
    udata->n_seq_id  .resize(n_tokens);
    udata->seq_id    .resize(n_tokens);
    udata->seq_id_unq.resize(0);
    udata->seq_idx   .resize(LLAMA_MAX_SEQ, -1);
    udata->output    .resize(n_tokens);

    // NOTE(review): tokens with multiple seq ids can exceed this reservation - it is only
    //   a size hint; push_back below grows the buffer as needed
    udata->seq_id_data.reserve(n_tokens);

    seq_set_t seq_set_unq;

    for (size_t i = 0; i < idxs.size(); ++i) {
        if (batch.token) {
            udata->token[i] = batch.token[idxs[i]];
        }

        if (batch.embd) {
            memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
        }

        for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
            // if we are using M-RoPE
            //     if the current batch is text, we need to broadcast the same position across all RoPE sections
            //     otherwise, the input batch is image embeddings, we copy the positions as-is
            // if we are not using M-RoPE, there is only one position per token (this loop runs only once)
            size_t src_off = batch.token ? 0 : j*batch.n_tokens;
            udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
        }

        udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
        udata->output[i]   = batch.logits[idxs[i]];

        for (int s = 0; s < udata->n_seq_id[i]; ++s) {
            const llama_seq_id seq_id = batch.seq_id[idxs[i]][s];

            udata->seq_id_data.push_back(seq_id);
            seq_set_unq.set(seq_id);
        }

        if (udata->output[i]) {
            // remember the original batch index of each output token so the results
            // can be reordered later
            out_ids.push_back(idxs[i]);
        }
    }

    // set up the per-token seq_id pointers only after seq_id_data has stopped growing
    //   (the push_back calls above may reallocate the buffer and invalidate pointers)
    llama_seq_id * seq_id_ptr = udata->seq_id_data.data();
    for (size_t i = 0; i < idxs.size(); ++i) {
        udata->seq_id[i] = seq_id_ptr;
        seq_id_ptr += udata->n_seq_id[i];
    }

    // build the unique sequence id list and the seq_id -> index mapping
    for (uint32_t s = 0; s < n_seq_max; ++s) {
        if (seq_set_unq.test(s)) {
            udata->seq_idx[s] = udata->seq_id_unq.size();
            udata->seq_id_unq.push_back(s);
        }
    }

    llama_ubatch res {
        /*.b_equal_seqs =*/ equal_seqs,
        /*.n_tokens     =*/ n_tokens,
        /*.n_seq_tokens =*/ n_tokens/n_seqs,
        /*.n_seqs       =*/ n_seqs,
        /*.n_seqs_unq   =*/ (uint32_t) udata->seq_id_unq.size(),
        /*.n_pos        =*/ n_pos_per_embd,

        /*.token        =*/ batch.token ? udata->token.data() : nullptr,
        /*.embd         =*/ batch.embd ? udata->embd.data() : nullptr,
        /*.pos          =*/ udata->pos.data(),
        /*.n_seq_id     =*/ udata->n_seq_id.data(),
        /*.seq_id       =*/ udata->seq_id.data(),
        /*.seq_id_unq   =*/ udata->seq_id_unq.data(),
        /*.seq_idx      =*/ udata->seq_idx.data(),
        /*.output       =*/ udata->output.data(),
        /*.data         =*/ std::move(udata),
    };

    if (debug > 0) {
        LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__);

        ubatch_print(res, debug);
    }

    return res;
}
775
776void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
777    if (debug > 0) {
778        LLAMA_LOG_DEBUG("%s:   equal_seqs   = %d\n", __func__, ubatch.equal_seqs());
779        LLAMA_LOG_DEBUG("%s:   n_tokens     = %d\n", __func__, ubatch.n_tokens);
780        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
781        LLAMA_LOG_DEBUG("%s:   n_seqs       = %d\n", __func__, ubatch.n_seqs);
782        LLAMA_LOG_DEBUG("%s:   n_seqs_unq   = %d\n", __func__, ubatch.n_seqs_unq);
783
784        std::stringstream ss_seq_id_unq;
785        std::stringstream ss_seq_idx;
786
787        ss_seq_id_unq << "[ ";
788        ss_seq_idx << "[";
789
790        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
791            ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
792        }
793
794        for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
795            if (ubatch.seq_idx[s] >= 0) {
796                ss_seq_idx << ubatch.seq_idx[s]%10;
797            } else {
798                ss_seq_idx << ".";
799            }
800        }
801
802        ss_seq_id_unq << "]";
803        ss_seq_idx    << "]";
804
805        LLAMA_LOG_DEBUG("%s:   token      = %p\n", __func__, (void *) ubatch.token);
806        LLAMA_LOG_DEBUG("%s:   embd       = %p\n", __func__, (void *) ubatch.embd);
807        LLAMA_LOG_DEBUG("%s:   pos        = %p\n", __func__, (void *) ubatch.pos);
808        LLAMA_LOG_DEBUG("%s:   n_seq_id   = %p\n", __func__, (void *) ubatch.n_seq_id);
809        LLAMA_LOG_DEBUG("%s:   seq_id     = %p\n", __func__, (void *) ubatch.seq_id);
810        LLAMA_LOG_DEBUG("%s:   seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
811        LLAMA_LOG_DEBUG("%s:   seq_idx    = %s\n", __func__, ss_seq_idx.str().c_str());
812        LLAMA_LOG_DEBUG("%s:   output     = %p\n", __func__, (void *) ubatch.output);
813        LLAMA_LOG_DEBUG("%s:   n_outputs  = %d\n", __func__, n_outputs);
814
815        if (debug > 1) {
816            int seq_id_max = 0;
817            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
818                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
819                    for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
820                        seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
821                    }
822                }
823            }
824            ++seq_id_max;
825
826            LLAMA_LOG_DEBUG("%s:   token     = [\n", __func__);
827            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
828                std::vector<int8_t> seq_id(seq_id_max);
829
830                for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
831                    seq_id[ubatch.seq_id[i][s]] = 1;
832                }
833
834                std::stringstream ss;
835                for (int s = 0; s < seq_id_max; ++s) {
836                    if (seq_id[s]) {
837                        ss << s%10;
838                    } else {
839                        ss << ".";
840                    }
841                }
842
843                if (ubatch.token) {
844                    LLAMA_LOG_DEBUG("%s:  %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
845                            __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
846                            ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
847                } else {
848                    LLAMA_LOG_DEBUG("%s:  %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
849                            __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
850                }
851            }
852            LLAMA_LOG_DEBUG("%s:   ]\n", __func__);
853        }
854    }
855}
856
857//
858// interface implementation
859//
860
861struct llama_batch llama_batch_get_one(
862             llama_token * tokens,
863                 int32_t   n_tokens) {
864    return {
865        /*n_tokens =*/ n_tokens,
866        /*tokens   =*/ tokens,
867        /*embd     =*/ nullptr,
868        /*pos      =*/ nullptr,
869        /*n_seq_id =*/ nullptr,
870        /*seq_id   =*/ nullptr,
871        /*logits   =*/ nullptr,
872    };
873}
874
875struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
876    llama_batch batch = {
877        /*n_tokens =*/ 0,
878        /*tokens   =*/ nullptr,
879        /*embd     =*/ nullptr,
880        /*pos      =*/ nullptr,
881        /*n_seq_id =*/ nullptr,
882        /*seq_id   =*/ nullptr,
883        /*logits   =*/ nullptr,
884    };
885
886    if (embd) {
887        batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
888    } else {
889        batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc);
890    }
891
892    batch.pos      = (llama_pos *)     malloc(sizeof(llama_pos)      * n_tokens_alloc);
893    batch.n_seq_id = (int32_t *)       malloc(sizeof(int32_t)        * n_tokens_alloc);
894    batch.seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1));
895    for (int i = 0; i < n_tokens_alloc; ++i) {
896        batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
897    }
898    batch.seq_id[n_tokens_alloc] = nullptr;
899
900    batch.logits   = (int8_t *)        malloc(sizeof(int8_t)         * n_tokens_alloc);
901
902    return batch;
903}
904
905void llama_batch_free(struct llama_batch batch) {
906    if (batch.token)    free(batch.token);
907    if (batch.embd)     free(batch.embd);
908    if (batch.pos)      free(batch.pos);
909    if (batch.n_seq_id) free(batch.n_seq_id);
910    if (batch.seq_id) {
911        for (int i = 0; batch.seq_id[i] != nullptr; ++i) {
912            free(batch.seq_id[i]);
913        }
914        free(batch.seq_id);
915    }
916    if (batch.logits)   free(batch.logits);
917}