1#define _USE_MATH_DEFINES // For M_PI on MSVC
   2
   3#include "arg.h"
   4#include "common.h"
   5#include "sampling.h"
   6#include "log.h"
   7#include "llama.h"
   8
   9#define JSON_ASSERT GGML_ASSERT
  10#include <nlohmann/json.hpp>
  11
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <map>
#include <regex>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
  21
  22using json = nlohmann::ordered_json;
  23
  24enum outetts_version {
  25    OUTETTS_V0_2,
  26    OUTETTS_V0_3,
  27};
  28
  29//
  30// Terminal utils
  31//
  32
  33#define SQR(X)    ((X) * (X))
  34#define UNCUBE(x) x < 48 ? 0 : x < 115 ? 1 : (x - 35) / 40
  35
  36/**
  37 * Quantizes 24-bit RGB to xterm256 code range [16,256).
  38 */
  39static int rgb2xterm256(int r, int g, int b) {
  40    unsigned char cube[] = {0, 0137, 0207, 0257, 0327, 0377};
  41    int av, ir, ig, ib, il, qr, qg, qb, ql;
  42    av = r * .299 + g * .587 + b * .114 + .5;
  43    ql = (il = av > 238 ? 23 : (av - 3) / 10) * 10 + 8;
  44    qr = cube[(ir = UNCUBE(r))];
  45    qg = cube[(ig = UNCUBE(g))];
  46    qb = cube[(ib = UNCUBE(b))];
  47    if (SQR(qr - r) + SQR(qg - g) + SQR(qb - b) <=
  48        SQR(ql - r) + SQR(ql - g) + SQR(ql - b))
  49        return ir * 36 + ig * 6 + ib + 020;
  50    return il + 0350;
  51}
  52
  53static std::string set_xterm256_foreground(int r, int g, int b) {
  54    int x = rgb2xterm256(r, g, b);
  55    std::ostringstream oss;
  56    oss << "\033[38;5;" << x << "m";
  57    return oss.str();
  58}
  59
// ANSI escape sequences forming a red -> green gradient in xterm256 colors;
// presumably used to colorize generated output (usage is outside this chunk).
const std::vector<std::string> k_colors = {
    set_xterm256_foreground(220,   5,  12),
    set_xterm256_foreground(232,  96,  28),
    set_xterm256_foreground(241, 147,  45),
    set_xterm256_foreground(246, 193,  65),
    set_xterm256_foreground(247, 240,  86),
    set_xterm256_foreground(144, 201, 135),
    set_xterm256_foreground( 78, 178, 101),
};
  69
  70static void print_usage(int, char ** argv) {
  71    LOG("\nexample usage:\n");
  72    LOG("\n    %s -m model.gguf -p \"Hello!\"\n", argv[0]);
  73    LOG("\n");
  74}
  75
// Canonical 44-byte RIFF/WAVE header for 16-bit mono PCM.
// NOTE(review): this struct is written to disk verbatim via reinterpret_cast
// in save_wav16, so field order and sizes must match the RIFF layout exactly;
// with these member types the struct is naturally packed (no padding).
struct wav_header {
    char riff[4] = {'R', 'I', 'F', 'F'}; // RIFF chunk id
    uint32_t chunk_size;                 // total file size - 8
    char wave[4] = {'W', 'A', 'V', 'E'}; // RIFF form type
    char fmt[4] = {'f', 'm', 't', ' '};  // "fmt " sub-chunk id
    uint32_t fmt_chunk_size = 16;        // 16 for plain PCM
    uint16_t audio_format = 1; // PCM
    uint16_t num_channels = 1; // Mono
    uint32_t sample_rate;      // samples per second
    uint32_t byte_rate;        // sample_rate * num_channels * bytes-per-sample
    uint16_t block_align;      // num_channels * bytes-per-sample
    uint16_t bits_per_sample = 16;
    char data[4] = {'d', 'a', 't', 'a'}; // "data" sub-chunk id
    uint32_t data_size;                  // payload size in bytes
};
  91
  92static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
  93    std::ofstream file(fname, std::ios::binary);
  94    if (!file) {
  95        LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
  96        return false;
  97    }
  98
  99    wav_header header;
 100    header.sample_rate = sample_rate;
 101    header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
 102    header.block_align = header.num_channels * (header.bits_per_sample / 8);
 103    header.data_size = data.size() * (header.bits_per_sample / 8);
 104    header.chunk_size = 36 + header.data_size;
 105
 106    file.write(reinterpret_cast<const char*>(&header), sizeof(header));
 107
 108    for (const auto & sample : data) {
 109        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
 110        file.write(reinterpret_cast<const char*>(&pcm_sample), sizeof(pcm_sample));
 111    }
 112
 113    return file.good();
 114}
 115
 116static void fill_hann_window(int length, bool periodic, float * output) {
 117    int offset = -1;
 118    if (periodic) {
 119        offset = 0;
 120    }
 121    for (int i = 0; i < length; i++) {
 122        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
 123    }
 124}
 125
 126// very poor-man fft
 127static void twiddle(float * real, float * imag, int k, int N) {
 128    float angle = 2 * M_PI * k / N;
 129    *real = cos(angle);
 130    *imag = sin(angle);
 131}
 132
 133static void irfft(int n, const float * inp_cplx, float * out_real) {
 134    int N = n / 2 + 1;
 135
 136    std::vector<float> real_input(N);
 137    std::vector<float> imag_input(N);
 138    for (int i = 0; i < N; ++i) {
 139        real_input[i] = inp_cplx[2 * i];
 140        imag_input[i] = inp_cplx[2 * i + 1];
 141    }
 142
 143    std::vector<float> real_output(n);
 144    std::vector<float> imag_output(n);
 145
 146    for (int k = 0; k < n; ++k) {
 147        real_output[k] = 0.0f;
 148        imag_output[k] = 0.0f;
 149        for (int m = 0; m < N; ++m) {
 150            float twiddle_real;
 151            float twiddle_imag;
 152
 153            twiddle(&twiddle_real, &twiddle_imag, k * m, n);
 154
 155            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
 156            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
 157        }
 158    }
 159
 160    for (int i = 0; i < n; ++i) {
 161        out_real[i] = real_output[i] / N;
 162    }
 163}
 164
 165//
 166//  y = torch.nn.functional.fold(
 167//       data, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
 168//  )[:, 0, 0, pad:-pad]
 169//
 170// data.shape =  torch.Size([1, 1280, 261])
 171// output_size =  84480
 172// win_length =  1280
 173// hop_length =  320
 174// pad =  480
 175//
// Overlap-add: distributes consecutive length-n_win frames from `data`
// (stored back to back), advancing n_hop samples per frame, into a single
// signal of n_out - 2*n_pad samples. C++ port of the torch.nn.functional.fold
// call shown in the comment above.
static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
    int64_t output_height = n_out;
    int64_t kernel_w = n_win;
    int64_t stride_w = n_hop;
    // `width` is n_out, i.e. more columns than `data` holds frames; the
    // col_idx bound check below stops the extra columns from reading past
    // the end of `data`
    int64_t width    = n_out;

    output.resize(width, 0.0f);

    // col_idx walks `data` linearly; it must advance even for skipped
    // (out-of-range) destinations so frames stay aligned with their source
    int64_t col_idx = 0;
    for (int64_t w_col = 0; w_col < width; ++w_col) {
        // destination start is shifted left by n_pad, which effectively drops
        // the leading padding from the result
        int64_t start = w_col * stride_w - n_pad;
        int64_t end   = start + kernel_w;

        for (int64_t w_im = start; w_im < end; ++w_im) {
            if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
                output[w_im] += data[col_idx];
            }
            col_idx++;
        }
    }

    // trim the trailing padding (the leading one was handled by the shift)
    output.resize(n_out - 2 * n_pad);
}
 199
// TODO: not optimized at all
// Decode the vocoder's output embeddings into a mono waveform.
// Each of the n_codes embedding vectors is one spectral frame: the first
// n_embd/2 values are log-magnitudes, the second n_embd/2 are phases.
// Frames are converted back to the time domain with a naive inverse rFFT,
// Hann-windowed, overlap-added, and normalized by the squared-window envelope.
static std::vector<float> embd_to_audio(
        const float * embd,
        const int n_codes,
        const int n_embd,
        const int n_thread) {
    // fixed STFT geometry - presumably matches the vocoder's training setup.
    // NOTE(review): irfft reads n_fft + 2 floats per frame, so this assumes
    // n_embd == n_fft + 2 (i.e. 1282) - confirm with the vocoder model
    const int n_fft = 1280;
    const int n_hop = 320;
    const int n_win = 1280;
    const int n_pad = (n_win - n_hop)/2;
    const int n_out = (n_codes - 1)*n_hop + n_win;

    std::vector<float> hann(n_fft);

    fill_hann_window(hann.size(), true, hann.data());

    int n_spec = n_embd*n_codes;

    std::vector<float> E (n_spec); // embeddings transposed to [n_embd, n_codes]
    std::vector<float> S (n_spec); // interleaved (re, im) spectrum, bin-major
    std::vector<float> ST(n_spec); // interleaved (re, im) spectrum, frame-major

    // transpose: embd is [n_codes, n_embd]
    for (int l = 0; l < n_codes; ++l) {
        for (int k = 0; k < n_embd; ++k) {
            E[k*n_codes + l] = embd[l*n_embd + k];
        }
    }

    // combine the log-magnitude and phase halves into complex bins, capping
    // the magnitude at 1e2 to avoid blow-ups from exp()
    for (int k = 0; k < n_embd/2; ++k) {
        for (int l = 0; l < n_codes; ++l) {
            float mag = E[(k           )*n_codes + l];
            float phi = E[(k + n_embd/2)*n_codes + l];

            mag = exp(mag);

            if (mag > 1e2) {
                mag = 1e2;
            }
            S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
            S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
        }
    }

    // transpose back so each frame's half-spectrum is contiguous for irfft
    for (int l = 0; l < n_codes; ++l) {
        for (int k = 0; k < n_embd/2; ++k) {
            ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
            ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
        }
    }

    std::vector<float> res  (n_codes*n_fft); // windowed time-domain frames
    std::vector<float> hann2(n_codes*n_fft); // squared window, for normalization

    // frames are independent: thread i handles frames i, i + n_thread, ...
    std::vector<std::thread> workers(n_thread);
    for (int i = 0; i < n_thread; ++i) {
        workers[i] = std::thread([&, i]() {
            for (int l = i; l < n_codes; l += n_thread) {
                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
                for (int j = 0; j < n_fft; ++j) {
                    res  [l*n_fft + j] *= hann[j];
                    hann2[l*n_fft + j]  = hann[j] * hann[j];
                }
            }
        });
    }
    for (int i = 0; i < n_thread; ++i) {
        workers[i].join();
    }

    std::vector<float> audio;
    std::vector<float> env;

    // overlap-add the frames and the window envelope
    fold(res,   n_out, n_win, n_hop, n_pad, audio);
    fold(hann2, n_out, n_win, n_hop, n_pad, env); // TODO: can be done once

    // undo the analysis windowing: divide by the summed squared window
    for (size_t i = 0; i < audio.size(); ++i) {
        audio[i] /= env[i];
    }

    return audio;
}
 281
 282static const std::map<int, std::string> ones = {
 283    {0, "zero"}, {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"},
 284    {5, "five"}, {6, "six"}, {7, "seven"}, {8, "eight"}, {9, "nine"},
 285    {10, "ten"}, {11, "eleven"}, {12, "twelve"}, {13, "thirteen"}, {14, "fourteen"},
 286    {15, "fifteen"}, {16, "sixteen"}, {17, "seventeen"}, {18, "eighteen"}, {19, "nineteen"}
 287};
 288
 289static const std::map<int, std::string> tens = {
 290    {2, "twenty"}, {3, "thirty"}, {4, "forty"}, {5, "fifty"},
 291    {6, "sixty"}, {7, "seventy"}, {8, "eighty"}, {9, "ninety"}
 292};
 293
 294// Convert a number less than 1000 to words
 295static std::string convert_less_than_thousand(int num) {
 296    std::string result;
 297
 298    if (num >= 100) {
 299        result += ones.at(num / 100) + " hundred ";
 300        num %= 100;
 301    }
 302
 303    if (num >= 20) {
 304        result += tens.at(num / 10);
 305        if (num % 10 > 0) {
 306            result += "-" + ones.at(num % 10);
 307        }
 308    } else if (num > 0) {
 309        result += ones.at(num);
 310    }
 311
 312    return result;
 313}
 314
 315static std::string number_to_words(const std::string & number_str) {
 316    try {
 317        size_t decimal_pos = number_str.find('.');
 318        std::string integer_part = number_str.substr(0, decimal_pos);
 319
 320        int int_number = std::stoi(integer_part);
 321        std::string result;
 322
 323        if (int_number == 0) {
 324            result = "zero";
 325        } else {
 326            if (int_number >= 1000000000) {
 327                int billions = int_number / 1000000000;
 328                result += convert_less_than_thousand(billions) + " billion ";
 329                int_number %= 1000000000;
 330            }
 331
 332            if (int_number >= 1000000) {
 333                int millions = int_number / 1000000;
 334                result += convert_less_than_thousand(millions) + " million ";
 335                int_number %= 1000000;
 336            }
 337
 338            if (int_number >= 1000) {
 339                int thousands = int_number / 1000;
 340                result += convert_less_than_thousand(thousands) + " thousand ";
 341                int_number %= 1000;
 342            }
 343
 344            if (int_number > 0) {
 345                result += convert_less_than_thousand(int_number);
 346            }
 347        }
 348
 349        // Handle decimal part
 350        if (decimal_pos != std::string::npos) {
 351            result += " point";
 352            std::string decimal_part = number_str.substr(decimal_pos + 1);
 353            for (char digit : decimal_part) {
 354                result += " " + ones.at(digit - '0');
 355            }
 356        }
 357
 358        return result;
 359    } catch (const std::exception& e) {
 360        // Skip if fails
 361        return " ";
 362    }
 363}
 364
 365static std::string replace_numbers_with_words(const std::string & input_text) {
 366    std::regex number_pattern(R"(\d+(\.\d+)?)");
 367    std::string result;
 368    auto it = std::sregex_iterator(input_text.begin(), input_text.end(), number_pattern);
 369    auto end = std::sregex_iterator();
 370
 371    size_t last_pos = 0;
 372    for (std::sregex_iterator i = it; i != end; ++i) {
 373        const std::smatch& match = *i;
 374        result.append(input_text, last_pos, match.position() - last_pos);
 375        result.append(number_to_words(match.str()));
 376        last_pos = match.position() + match.length();
 377    }
 378    result.append(input_text, last_pos);
 379
 380    return result;
 381}
 382
 383// Based on: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/version/v1/prompt_processor.py#L39
 384static std::string process_text(const std::string & text, const outetts_version tts_version = OUTETTS_V0_2) {
 385
 386    // For now I skipped text romanization as I am unsure how to handle
 387    // uroman and MeCab implementations in C++
 388    // maybe something like https://github.com/anyascii/anyascii/ could work.
 389    // currently only English would be supported in this function
 390
 391    std::string processed_text = replace_numbers_with_words(text);
 392
 393    std::transform(processed_text.begin(), processed_text.end(),
 394                  processed_text.begin(), ::tolower);
 395
 396    std::regex special_chars(R"([-_/,\.\\])");
 397    processed_text = std::regex_replace(processed_text, special_chars, " ");
 398
 399    std::regex non_alpha(R"([^a-z\s])");
 400    processed_text = std::regex_replace(processed_text, non_alpha, "");
 401
 402    std::regex multiple_spaces(R"(\s+)");
 403    processed_text = std::regex_replace(processed_text, multiple_spaces, " ");
 404
 405    processed_text = std::regex_replace(processed_text, std::regex(R"(^\s+|\s+$)"), "");
 406
 407    /*
 408        Replace spaces with the separator token same as in line 365
 409
 410        for (auto & c : prompt_user) {
 411        if (c == ' ') {
 412            prompt_clean += "<|text_sep|>";
 413    */
 414    std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
 415    processed_text = std::regex_replace(processed_text, std::regex(R"(\s)"), separator);
 416
 417    return processed_text;
 418}
 419
 420static void prompt_add(llama_tokens & prompt, llama_token token) {
 421    prompt.push_back(token);
 422}
 423
 424static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
 425    prompt.insert(prompt.end(), tokens.begin(), tokens.end());
 426}
 427
 428static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
 429    auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
 430    prompt_add(prompt, tmp);
 431}
 432
 433static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
 434    prompt.clear();
 435
 436    prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
 437}
 438
 439static std::vector<llama_token> prepare_guide_tokens(const llama_vocab * vocab, const std::string & str, const outetts_version tts_version = OUTETTS_V0_2) {
 440    const std::string& delimiter = (tts_version == OUTETTS_V0_3 ? "<|space|>" : "<|text_sep|>");
 441
 442    std::vector<llama_token> result;
 443    size_t start = 0;
 444    size_t end = str.find(delimiter);
 445
 446    //first token is always a newline, as it was not previously added
 447    result.push_back(common_tokenize(vocab, "\n", false, true)[0]);
 448
 449    while (end != std::string::npos) {
 450        std::string current_word = str.substr(start, end - start);
 451        auto tmp = common_tokenize(vocab, current_word, false, true);
 452        result.push_back(tmp[0]);
 453        start = end + delimiter.length();
 454        end = str.find(delimiter, start);
 455    }
 456
 457    // Add the last part
 458    std::string current_word = str.substr(start);
 459    auto tmp = common_tokenize(vocab, current_word, false, true);
 460    if (tmp.size() > 0) {
 461        result.push_back(tmp[0]);
 462    }
 463    return result;
 464}
 465
 466static json speaker_from_file(const std::string & speaker_file) {
 467    std::ifstream file(speaker_file);
 468    if (!file) {
 469        LOG_ERR("%s: Failed to open file '%s' for reading\n", __func__, speaker_file.c_str());
 470        return json();
 471    }
 472
 473    json speaker = json::parse(file);
 474    return speaker;
 475}
 476
 477static outetts_version get_tts_version(llama_model *model, json speaker = json::object()) {
 478    if (speaker.contains("version")) {
 479        std::string version = speaker["version"].get<std::string>();
 480        if (version == "0.2") {
 481            return OUTETTS_V0_2;
 482        } else if (version == "0.3") {
 483            return OUTETTS_V0_3;
 484        } else {
 485            LOG_ERR("%s: Unsupported speaker version '%s'\n", __func__, version.c_str());
 486        }
 487    }
 488
 489    // Also could get version from model itself
 490    const char *chat_template = llama_model_chat_template(model, nullptr);
 491    if (chat_template && std::string(chat_template) == "outetts-0.3") {
 492        return OUTETTS_V0_3;
 493    }
 494
 495    // Use 0.2 as the default version
 496    return OUTETTS_V0_2;
 497}
 498
 499static std::string audio_text_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
 500    std::string audio_text = "<|text_start|>";
 501
 502    if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
 503        std::string separator = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|text_sep|>";
 504        for (const auto &word : speaker["words"]) {
 505            audio_text += word["word"].get<std::string>() + separator;
 506        }
 507    }
 508
 509    return audio_text;
 510}
 511
 512static std::string audio_data_from_speaker(json speaker, const outetts_version tts_version = OUTETTS_V0_2) {
 513    std::string audio_data = "<|audio_start|>\n";
 514
 515    if (tts_version == OUTETTS_V0_2 || tts_version == OUTETTS_V0_3) {
 516        std::string code_start = (tts_version == OUTETTS_V0_3) ? "" : "<|code_start|>";
 517        std::string code_end = (tts_version == OUTETTS_V0_3) ? "<|space|>" : "<|code_end|>";
 518        for (const auto &word : speaker["words"]) {
 519            std::string word_text = word["word"].get<std::string>();
 520            double duration = word["duration"].get<double>();
 521            std::vector<int> codes = word["codes"].get<std::vector<int>>();
 522
 523            // Create the audio output entry
 524            std::ostringstream word_entry;
 525            word_entry << word_text << "<|t_" << std::fixed << std::setprecision(2)
 526                       << duration << "|>" + code_start;
 527            for (const auto &Code : codes) {
 528                word_entry << "<|" << Code << "|>";
 529            }
 530            word_entry << code_end << "\n";
 531            audio_data += word_entry.str();
 532        }
 533    }
 534
 535    return audio_data;
 536}
 537
 538int main(int argc, char ** argv) {
 539    common_params params;
 540
 541    params.out_file = "output.wav";
 542    params.prompt = "";
 543
 544    params.n_predict = 4096;
 545    params.n_batch   = 8192;
 546    params.n_ctx     = 8192;
 547
 548    params.sampling.top_k = 4;
 549    params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, };
 550
 551    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
 552        return 1;
 553    }
 554
 555    const int n_parallel = params.n_parallel;
 556    const int n_predict  = params.n_predict;
 557
 558    common_init();
 559
 560    // init LLM
 561
 562    llama_backend_init();
 563    llama_numa_init(params.numa);
 564
 565    llama_model * model_ttc = NULL; // text-to-codes
 566    llama_model * model_cts = NULL; // codes-to-speech
 567
 568    llama_context * ctx_ttc = NULL;
 569    llama_context * ctx_cts = NULL;
 570
 571    auto llama_init_ttc = common_init_from_params(params);
 572
 573    model_ttc = llama_init_ttc->model();
 574    ctx_ttc   = llama_init_ttc->context();
 575
 576    if (model_ttc == nullptr || ctx_ttc == nullptr) {
 577        return ENOENT;
 578    }
 579
 580    const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
 581
 582    params.model = params.vocoder.model;
 583    params.embedding = true;
 584    params.n_ubatch = params.n_batch;
 585
 586    auto llama_init_cts = common_init_from_params(params);
 587
 588    model_cts = llama_init_cts->model();
 589    ctx_cts   = llama_init_cts->context();
 590
 591    if (model_cts == nullptr || ctx_cts == nullptr) {
 592        return ENOENT;
 593    }
 594
 595    std::vector<common_sampler *> smpl(n_parallel);
 596    for (int i = 0; i < n_parallel; ++i) {
 597        params.sampling.no_perf = (i != 0);
 598        params.sampling.seed = params.sampling.seed + 1;
 599
 600        smpl[i] = common_sampler_init(model_ttc, params.sampling);
 601    }
 602
 603    LOG_INF("sampler seed: %u\n",     common_sampler_get_seed(smpl[0]));
 604    LOG_INF("sampler params: \n%s\n", params.sampling.print().c_str());
 605    LOG_INF("sampler chain: %s\n",    common_sampler_print(smpl[0]).c_str());
 606
 607    LOG_INF("%s: loading done\n", __func__);
 608
 609    const auto t_main_start = ggml_time_us();
 610
 611    std::vector<llama_token> codes;
 612    std::vector<llama_token> guide_tokens;
 613
 614    // the default speaker profile is from: https://github.com/edwko/OuteTTS/blob/main/outetts/version/v1/default_speakers/en_male_1.json
 615    std::string audio_text = "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>";
 616    std::string audio_data = R"(<|audio_start|>
 617the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|>
 618overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|>
 619package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|>
 620from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|>
 621just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|>
 622two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|>
 623people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|>
 624is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|>
 625pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|>
 626remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|>
 627sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|>
 628i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|>
 629have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|>
 630some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|>
 631critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|>
 632about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|>
 633some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|>
 634of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|>
 635the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|>
 636gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|>
 637aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|>
 638but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|>
 639its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|>
 640still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|>
 641really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|>
 642enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|>
 643and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|>
 644it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|>
 645looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
 646lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
 647
 648    // audio data for 0.3 version
 649    outetts_version tts_version = get_tts_version(model_ttc);
 650    if (tts_version == OUTETTS_V0_3) {
 651        audio_text = std::regex_replace(audio_text, std::regex(R"(<\|text_sep\|>)"), "<|space|>");
 652        audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_start\|>)"), "");
 653        audio_data = std::regex_replace(audio_data, std::regex(R"(<\|code_end\|>)"), "<|space|>");
 654    }
 655
 656    // load speaker if given
 657    if (!params.vocoder.speaker_file.empty()) {
 658        LOG_INF("%s: loading speaker ..\n", __func__);
 659        json speaker = speaker_from_file(params.vocoder.speaker_file);
 660        if (speaker.empty()) {
 661            LOG_ERR("%s: Failed to load speaker file '%s'\n", __func__, params.vocoder.speaker_file.c_str());
 662            return 1;
 663        }
 664        audio_text = audio_text_from_speaker(speaker, tts_version);
 665        audio_data = audio_data_from_speaker(speaker, tts_version);
 666    }
 667
 668    // process prompt and generate voice codes
 669    {
 670        LOG_INF("%s: constructing prompt ..\n", __func__);
 671
 672        std::vector<llama_token> prompt_inp;
 673
 674        prompt_init(prompt_inp, vocab);
 675
 676        prompt_add(prompt_inp, vocab, audio_text, false, true);
 677
 678        // convert the input text into the necessary format expected by OuteTTS
 679        {
 680            std::string prompt_clean = process_text(params.prompt, tts_version);
 681            if (params.vocoder.use_guide_tokens) {
 682                guide_tokens = prepare_guide_tokens(vocab, prompt_clean, tts_version);
 683            }
 684
 685            LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
 686
 687            prompt_add(prompt_inp, vocab, prompt_clean, false, true);
 688        }
 689
 690        prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
 691
 692        if (!params.vocoder.speaker_file.empty()) {
 693            prompt_add(prompt_inp, vocab, audio_data, false, true);
 694        } else {
 695            // disabled to save time on tokenizing each time
 696#if 1
 697            const std::string voice_data = audio_data;
 698
 699            auto tmp = common_tokenize(vocab, voice_data, false, true);
 700
 701            std::ostringstream tokens_oss;
 702            for (size_t i = 0; i < tmp.size(); ++i) {
 703                tokens_oss << tmp[i] << ", ";
 704            }
 705            LOG_INF("\n\n%s: llama tokens: %s\n\n", __func__, tokens_oss.str().c_str());
 706
 707            prompt_add(prompt_inp, tmp);
 708#else
 709            prompt_add(prompt_inp, llama_tokens {
 710                151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585,
 711                152460, 153375, 151670, 198, 74455, 155808, 151669, 151799,
 712                151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470,
 713                151970, 153413, 152419, 153334, 153289, 153374, 153199, 152040,
 714                153260, 152721, 152680, 153297, 152419, 153248, 152400, 152691,
 715                153368, 153437, 151670, 198, 1722, 155828, 151669, 152607,
 716                152256, 152991, 152299, 152688, 153163, 153016, 152789, 153198,
 717                152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207,
 718                152461, 153321, 153309, 151750, 152137, 153340, 152573, 152267,
 719                153347, 151789, 152681, 153339, 151992, 152512, 151751, 152179,
 720                153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904,
 721                152311, 151670, 198, 1499, 155791, 151669, 152276, 152454,
 722                153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226,
 723                153043, 152325, 153267, 152622, 151670, 198, 4250, 155797,
 724                151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271,
 725                152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213,
 726                152112, 153204, 151722, 152542, 151670, 198, 19789, 155796,
 727                151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002,
 728                152191, 151734, 152312, 152810, 152237, 153224, 153169, 153224,
 729                152244, 153387, 153404, 151670, 198, 16069, 155811, 151669,
 730                152265, 151946, 151808, 152412, 152363, 152305, 153156, 152733,
 731                152810, 153157, 152016, 152100, 152069, 153234, 152317, 152589,
 732                152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504,
 733                153376, 152272, 152433, 152325, 151941, 151670, 198, 285,
 734                155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381,
 735                152474, 152680, 152157, 153255, 152324, 151682, 151670, 198,
 736                32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682,
 737                152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488,
 738                153070, 151883, 152890, 152489, 153144, 153375, 152358, 151685,
 739                152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669,
 740                151902, 152720, 153377, 152027, 152378, 152821, 153207, 153459,
 741                153028, 153068, 152507, 153255, 152158, 152921, 151958, 152609,
 742                152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470,
 743                152606, 152162, 152186, 153071, 152244, 153118, 153375, 153018,
 744                152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736,
 745                153380, 153502, 152702, 152115, 153181, 152735, 153277, 153457,
 746                152393, 153112, 152595, 151670, 198, 19098, 155808, 151669,
 747                152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239,
 748                153163, 152922, 153402, 152034, 152591, 153438, 152215, 151673,
 749                152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482,
 750                152718, 152862, 153347, 151670, 198, 72, 155780, 151669, 151795,
 751                152111, 152746, 152377, 153471, 152309, 151670, 198, 19016,
 752                155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701,
 753                152939, 152536, 152091, 151815, 152733, 151672, 151670, 198,
 754                14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042,
 755                153504, 152589, 153333, 151839, 151941, 153038, 153180, 151670,
 756                198, 36996, 8303, 155832, 151669, 152231, 152256, 152835,
 757                152801, 152985, 153400, 152393, 152818, 152765, 152249, 152600,
 758                151699, 152302, 152752, 153018, 153009, 151992, 153054, 152847,
 759                153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458,
 760                152048, 152757, 152428, 153195, 151906, 153006, 153178, 153250,
 761                152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418,
 762                152228, 152733, 151670, 198, 9096, 155801, 151669, 151698,
 763                153321, 152217, 153039, 152935, 153400, 152122, 152531, 153106,
 764                152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851,
 765                152901, 152885, 152594, 153446, 153080, 151670, 198, 14689,
 766                155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191,
 767                151673, 151690, 151698, 152714, 152846, 152981, 153171, 153384,
 768                153364, 153188, 153246, 151670, 198, 1055, 155779, 151669,
 769                151869, 152388, 152711, 153334, 151736, 151670, 198, 1782,
 770                155780, 151669, 153483, 153240, 152241, 152558, 152697, 153046,
 771                151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605,
 772                153034, 153434, 153372, 153347, 151887, 152453, 152758, 152133,
 773                152510, 152694, 152431, 152321, 153088, 152676, 152223, 152581,
 774                152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032,
 775                152903, 152859, 152989, 151748, 152669, 152661, 152650, 152409,
 776                151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469,
 777                152988, 152894, 151819, 152391, 153019, 152058, 153062, 153230,
 778                151826, 152112, 152306, 152264, 152769, 153390, 152384, 152435,
 779                152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540,
 780                151919, 151893, 152558, 152817, 152946, 152956, 152129, 152715,
 781                153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450,
 782                151670, 198, 8088, 155792, 151669, 152452, 153497, 153353,
 783                152679, 152533, 152382, 152374, 152611, 153341, 153163, 152285,
 784                153411, 152495, 153141, 152320, 151670, 198, 1199, 155781,
 785                151669, 151764, 152360, 153295, 152634, 153342, 152199, 152271,
 786                151670, 198, 43366, 155799, 151669, 152308, 151682, 152889,
 787                152016, 152385, 152629, 152495, 151826, 153321, 152958, 152180,
 788                151886, 153432, 152922, 152128, 153024, 153040, 152593, 152287,
 789                151677, 151670, 198, 53660, 155808, 151669, 151727, 152092,
 790                152680, 153331, 151699, 152316, 152938, 152289, 152433, 153384,
 791                151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691,
 792                152489, 151941, 152049, 152034, 153053, 152179, 153160, 151676,
 793                153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350,
 794                152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234,
 795                153135, 152291, 153235, 152143, 152583, 152402, 153483, 152678,
 796                152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825,
 797                152548, 153442, 152109, 152659, 153325, 152781, 152570, 152957,
 798                151752, 152265, 153381, 152515, 151670, 198, 437, 155787,
 799                151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174,
 800                151792, 153409, 153327, 152990, 151670, 198, 275, 155781,
 801                151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974,
 802                151670, 198, 94273, 155799, 151669, 152953, 152938, 153427,
 803                152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331,
 804                152257, 152987, 152777, 153448, 152408, 151696, 152408, 152326,
 805                152699, 151670, 198, 385, 16239, 155828, 151669, 152306, 152268,
 806                153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110,
 807                152918, 152923, 152467, 152331, 153053, 153330, 151889, 153444,
 808                152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751,
 809                152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499,
 810                152109, 152255, 151739, 152267, 152759, 153318, 153165, 153349,
 811                151670,});
 812#endif
 813        }
 814
 815        // print the prompt token-by-token
 816
 817        LOG("\n");
 818
 819        for (auto id : prompt_inp) {
 820            LOG("%s", common_token_to_piece(ctx_ttc, id).c_str());
 821        }
 822
 823        LOG_INF("%s: prompt size: %d\n", __func__, (int) prompt_inp.size());
 824
 825        LOG("\n");
 826
 827        // create a llama_batch
 828        // we use this object to submit token data for decoding
 829        llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel);
 830
 831        std::vector<llama_seq_id> seq_ids(n_parallel, 0);
 832        for (int32_t i = 0; i < n_parallel; ++i) {
 833            seq_ids[i] = i;
 834        }
 835
 836        // evaluate the initial prompt
 837        for (size_t i = 0; i < prompt_inp.size(); ++i) {
 838            common_batch_add(batch, prompt_inp[i], i, seq_ids, false);
 839        }
 840        GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size());
 841
 842        // llama_decode will output logits only for the last token of the prompt
 843        batch.logits[batch.n_tokens - 1] = true;
 844
 845        if (llama_decode(ctx_ttc, batch) != 0) {
 846            LOG_ERR("%s: llama_decode() failed\n", __func__);
 847            return 1;
 848        }
 849
 850        if (n_parallel > 1) {
 851            LOG_INF("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
 852        }
 853
 854        llama_synchronize(ctx_ttc);
 855
 856        LOG_INF("%s: time for prompt: %.3f ms\n\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
 857
 858        const auto t_dec_start = ggml_time_us();
 859
        // main loop
        //
        // sample up to n_predict tokens per sequence; each iteration samples one
        // token for every still-active parallel stream, then decodes them together

        // remember the batch index of the last token for each parallel sequence
        // we need this to determine which logits to sample from
        std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);

        int n_past   = batch.n_tokens;
        int n_decode = 0;

        // the first sampled token starts a new word, so it is eligible for guiding
        bool next_token_uses_guide_token = true;

        while (n_decode <= n_predict) {
            // prepare the next batch
            common_batch_clear(batch);

            // sample the next token for each parallel sequence / stream
            for (int32_t i = 0; i < n_parallel; ++i) {
                if (i_batch[i] < 0) {
                    // the stream has already finished
                    continue;
                }

                llama_token new_token_id = common_sampler_sample(smpl[i], ctx_ttc, i_batch[i]);

                //guide tokens help prevent hallucinations by forcing the TTS to use the correct word
                // only override non-control, non-EOG tokens; consume one guide token per word
                if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
                    llama_token guide_token = guide_tokens[0];
                    guide_tokens.erase(guide_tokens.begin());
                    new_token_id = guide_token; //ensure correct word fragment is used
                }

                //this is the token id that always precedes a new word
                next_token_uses_guide_token = (new_token_id == 198);

                // accept the (possibly overridden) token so sampler state stays consistent
                common_sampler_accept(smpl[i], new_token_id, true);

                // NOTE(review): with n_parallel > 1 this interleaves tokens from all
                // streams into one vector - verify intended before enabling parallelism
                codes.push_back(new_token_id);

                const auto * cands = common_sampler_get_candidates(smpl[i], false);

                // is it an end of generation? -> mark the stream as finished
                if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {
                    std::string reason;
                    if (llama_vocab_is_eog(vocab, new_token_id)) {
                        reason = "eos";
                    } else {
                        reason = "n_predict";
                    }

                    i_batch[i] = -1;

                    LOG("\n");
                    if (n_parallel > 1) {
                        LOG_CNT("\n");
                        LOG_INF("%s: stream %d finished at n_past = %d, reason = '%s'\n", __func__, i, n_past, reason.c_str());
                    }

                    continue;
                }

                {
                    // color the printed stream index by the selected token's probability
                    // (low p -> red end of k_colors, high p -> green end)
                    const float p = cands->data[cands->selected].p;

                    const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) ((3*p)*float(k_colors.size()))));

                    LOG_CNT("%s%d%s", k_colors[col].c_str(), i, "\033[0m");
                    //LOG_CNT("%d", i);
                }

                // remember where this stream's logits will be in the next batch
                i_batch[i] = batch.n_tokens;

                // push this new token for next evaluation
                common_batch_add(batch, new_token_id, n_past, { i }, true);
            }

            // all streams are finished
            if (batch.n_tokens == 0) {
                break;
            }

            n_decode += 1;
            n_past += 1;

            // evaluate the current batch with the transformer model
            if (llama_decode(ctx_ttc, batch)) {
                LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
                return 1;
            }
        }

        llama_batch_free(batch);

        LOG("\n");
        LOG_INF("%s: time for decoder:       %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f);
 954    }
 955
    // print decode/sampling performance stats (only the first sampler is reported)
    common_perf_print(ctx_ttc, smpl[0]);
 957
 958    //std::vector<llama_token> codes = {198, 88225, 155856, 151669, 152205,
 959    //    153064, 152537, 153421, 153209, 152524, 151689, 152993, 152438, 152695,
 960    //    153091, 152945, 152829, 152534, 152934, 153020, 151997, 152263, 153010,
 961    //    153146, 152399, 153208, 152496, 151793, 152848, 152263, 152571, 153286,
 962    //    152227, 153300, 152934, 152263, 153208, 152263, 152965, 152430, 152296,
 963    //    153146, 152920, 152376, 152556, 153363, 151775, 152044, 152972, 152690,
 964    //    153379, 152368, 152233, 153422, 152490, 151996, 152022, 151694, 152061,
 965    //    153238, 152539, 153356, 152640, 153021, 153123, 151962, 153094, 151670,
 966    //    198, 20339, 13189, 155824, 151669, 152070, 152007, 152910, 151683,
 967    //    152000, 152373, 152760, 152046, 151735, 152334, 152394, 153073, 152908,
 968    //    151856, 151953, 153247, 153293, 151903, 153480, 153168, 152478, 153359,
 969    //    153429, 151905, 151678, 152567, 152411, 152165, 152556, 153075, 153424,
 970    //    151993, 152999, 153078, 152151, 152088, 153389, 152484, 151874, 151670,
 971    //    198, 285, 155784, 151669, 152226, 152126, 152638, 153215, 151729,
 972    //    152959, 153479, 153059, 151838, 151670, 198, 1782, 155783, 151669,
 973    //    153288, 153055, 153314, 152497, 152962, 152741, 152076, 153253, 151670,
 974    //    198, 471, 16488, 155825, 151669, 152060, 152916, 151893, 153469, 152501,
 975    //    152080, 152743, 151932, 153161, 152096, 152761, 152698, 153401, 153242,
 976    //    153336, 152441, 152838, 153467, 152706, 153496, 153310, 152422, 153360,
 977    //    153115, 152763, 151998, 152373, 153450, 152554, 151968, 153323, 152055,
 978    //    152468, 153111, 153358, 152813, 152010, 151770, 152823, 152960, 151670,
 979    //    198, 22627, 155823, 151669, 152814, 152366, 153484, 152931, 153441,
 980    //    152164, 152877, 152915, 153463, 151692, 152911, 152747, 152776, 151831,
 981    //    153449, 151882, 152975, 152031, 152513, 153150, 152448, 152667, 153133,
 982    //    153189, 152619, 153466, 152054, 152106, 153119, 152277, 152439, 153109,
 983    //    152997, 152141, 153154, 153256, 153311, 151922, 151670, 198, 1055,
 984    //    155781, 151669, 152633, 151850, 153060, 153270, 152560, 153348, 152729,
 985    //    151670, 198, 25312, 155803, 151669, 152521, 153403, 152561, 153337,
 986    //    153383, 152199, 153493, 153326, 151830, 152254, 152248, 152349, 152153,
 987    //    153007, 151823, 153037, 152575, 152457, 152406, 152592, 153116, 153365,
 988    //    153456, 151670, 198, 88225, 155817, 151669, 153271, 151925, 152218,
 989    //    152418, 152253, 153140, 151903, 153151, 152626, 152338, 152647, 153464,
 990    //    152785, 152768, 151711, 152037, 152033, 151804, 152216, 151701, 151855,
 991    //    152348, 152995, 152955, 152905, 152342, 152340, 153391, 153453, 152418,
 992    //    153415, 151990, 153083, 152884, 151670, 198, 151668, 198, 151645};
 993
 994    {
 995        const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
 996
 997        LOG("\n");
 998        LOG_INF("codes: '%s'\n", inp_txt.c_str());
 999        LOG_INF("%s: codes size: %d\n", __func__, (int) codes.size());
1000    }
1001
1002    // remove all non-audio tokens (i.e. < 151672 || > 155772)
1003    codes.erase(std::remove_if(codes.begin(), codes.end(), [](llama_token t) { return t < 151672 || t > 155772; }), codes.end());
1004
1005    {
1006        const std::string inp_txt = common_detokenize(ctx_ttc, codes, true);
1007        LOG_INF("codes audio: '%s'\n", inp_txt.c_str());
1008        LOG_INF("%s: codes audio size: %d\n", __func__, (int) codes.size());
1009    }
1010
1011    for (auto & token : codes) {
1012        token -= 151672;
1013    }
1014
    const auto t_voc_start = ggml_time_us();

    const int n_codes = codes.size();

    // single-sequence batch holding all audio codes for the vocoder model
    llama_batch batch = llama_batch_init(n_codes, 0, 1);

    for (size_t i = 0; i < codes.size(); ++i) {
        common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits?
    }
    GGML_ASSERT(batch.n_tokens == n_codes);

    // the vocoder is run as an encoder: it produces embeddings (the spectrogram),
    // not logits
    if (llama_encode(ctx_cts, batch) != 0) {
        LOG_ERR("%s: llama_encode() failed\n", __func__);
        return 1;
    }

    llama_synchronize(ctx_cts);

    LOG_INF("%s: time for vocoder:      %.3f ms\n", __func__, (ggml_time_us() - t_voc_start) / 1000.0f);
1034
    const auto t_spec_start = ggml_time_us();

#if 1
    // spectral operations: convert the vocoder's output embeddings (one
    // n_embd-sized vector per code) into a waveform
    const int n_embd = llama_model_n_embd_out(model_cts);
    const float * embd = llama_get_embeddings(ctx_cts);

    auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);

#else
    // read the spectrogram from a file for debugging purposes
    std::vector<float> audio;
    {
        std::ifstream fin("out.bin", std::ios::binary);
        if (!fin) {
            LOG_ERR("%s: failed to open file '%s'\n", __func__, "out.bin");
            return 1;
        }

        std::vector<float> embd;

        int n_codes;
        int n_embd;

        // file layout: int n_codes, int n_embd, then n_codes*n_embd floats
        fin.read(reinterpret_cast<char *>(&n_codes), sizeof(int));
        fin.read(reinterpret_cast<char *>(&n_embd), sizeof(int));

        embd.resize(n_codes * n_embd);
        fin.read(reinterpret_cast<char *>(embd.data()), n_codes * n_embd * sizeof(float));
        fin.close();

        LOG_INF("%s: n_codes: %d, n_embd: %d\n", __func__, n_codes, n_embd);

        audio = embd_to_audio(embd.data(), n_codes, n_embd, params.cpuparams.n_threads);
    }
#endif

    const int n_sr = 24000; // sampling rate

    // zero out first 0.25 seconds
    // NOTE(review): hard-codes 24000/4 instead of n_sr/4 - same value today, but
    // would silently break if n_sr changes; also assumes audio.size() >= n_sr/4
    for (int i = 0; i < 24000/4; ++i) {
        audio[i] = 0.0f;
    }

    LOG_INF("%s: time for spectral ops: %.3f ms\n", __func__, (ggml_time_us() - t_spec_start) / 1000.0f);
    LOG_INF("%s: total time:            %.3f ms\n", __func__, (ggml_time_us() - t_main_start) / 1000.0f);
1081
    int retval = 0;

    // write the waveform as a 16-bit WAV file; on failure reuse ENOENT as the
    // process exit code
    if (save_wav16(params.out_file, audio, n_sr)) {
        LOG_INF("%s: audio written to file '%s'\n", __func__, params.out_file.c_str());
    } else {
        retval = ENOENT;
    }

    llama_backend_free();
1093}