#include "arg.h"
#include "chat.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <limits.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <limits>
#include <random>
#include <string>
#include <vector>

enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };

// Unified transfer scheduling methods
enum transfer_schedule {
    TIMESTEP_BASED = 0,  // Dream-style: (1.0 - s/t) * remaining
    BLOCK_BASED    = 1,  // LLaDA-style: process in blocks with get_num_transfer_tokens
};
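
// With TIMESTEP_BASED scheduling, step i of T maps to the continuous times
//   t = 1 - (i / T) * (1 - eps),   s = 1 - ((i + 1) / T) * (1 - eps)
// and a fraction (1 - s/t) of the still-masked positions is unmasked each
// step (all of them on the final step). BLOCK_BASED scheduling instead splits
// the masked region into fixed-size blocks and spreads each block's masked
// tokens evenly over the steps allotted to that block.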

typedef bool (*diffusion_step_callback_t)(int32_t             step,
                                          int32_t             total_steps,
                                          const llama_token * tokens,
                                          int32_t             n_tokens,
                                          void *              user_data);

struct diffusion_params {
    int32_t                   steps                   = 0;
    float                     temperature             = 0.0f;
    llama_token               mask_token_id           = LLAMA_TOKEN_NULL;
    diffusion_step_callback_t step_callback           = nullptr;
    void *                    step_callback_user_data = nullptr;
    int32_t                   seed                    = 0;
    bool                      visual_mode             = false;
    bool                      shift_logits            = false;  // Shift logits by -1 after decode

    float   top_p = 0.0f;
    int32_t top_k = 0;

    diffusion_algorithm algorithm = CONFIDENCE_BASED;
    transfer_schedule   schedule  = TIMESTEP_BASED;

    float   cfg_scale        = 0.0f;   // Scale for classifier-free guidance
    float   eps              = 0.0f;   // Timestep scheduling
    int32_t block_length     = 0;      // Block size (for block scheduling)
    float   alg_temp         = 0.0f;   // Algorithm temperature (0.0 = deterministic)
    bool    add_gumbel_noise = false;  // Add Gumbel noise to the logits if temp > 0.0

    int32_t max_length = 0;            // Maximum sequence length
};

struct callback_data {
    diffusion_params *  diff_params;
    const llama_vocab * vocab;
    int32_t             n_input;
};

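// Per-position confidence used to decide which masked positions get unmasked
// first: CONFIDENCE_BASED and ORIGIN use the sampled token's probability,
// ENTROPY_BASED uses the negative Shannon entropy of the distribution,
// MARGIN_BASED uses the top-1/top-2 probability gap, and RANDOM assigns
// uniform random scores.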
static float calculate_confidence(const llama_token_data_array & cur_p,
                                  diffusion_algorithm            algorithm,
                                  std::mt19937 &                 rng) {
    switch (algorithm) {
        case CONFIDENCE_BASED:
            return cur_p.data[cur_p.selected].p;  // Selected token probability

        case ENTROPY_BASED:
            {
                // Accumulate sum(p * log p), i.e. the negative entropy, so that
                // peaked (low-entropy) distributions rank as more confident
                float       neg_entropy = 0.0f;
                const float epsilon     = 1e-10f;  // prevent log(0)
                for (size_t i = 0; i < cur_p.size; i++) {
                    float prob = cur_p.data[i].p;
                    neg_entropy += prob * logf(prob + epsilon);
                }
                return neg_entropy;  // Higher entropy = lower confidence
            }

        case MARGIN_BASED:
            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;

        case RANDOM:
            {
                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
                return uniform(rng);  // Random confidence
            }

        case ORIGIN:
            return cur_p.data[cur_p.selected].p;

        default:
            return 0.0f;
    }
}

// Unified transfer count calculation function
static int32_t calculate_transfer_count(int32_t                      step,
                                        int32_t                      total_steps,
                                        int32_t                      remaining_masked,
                                        transfer_schedule            schedule,
                                        float                        eps,
                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
    switch (schedule) {
        case TIMESTEP_BASED:
            {
                float t          = 1.0f - (float) step / total_steps * (1.0f - eps);
                float s          = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
                return (int32_t) (remaining_masked * p_transfer);
            }

        case BLOCK_BASED:
            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
                return num_transfer_tokens[step];
            }
            return remaining_masked / (total_steps - step);  // Fallback

        default:
            return remaining_masked / (total_steps - step);
    }
}

static bool diffusion_step_callback(int32_t             step,
                                    int32_t             total_steps,
                                    const llama_token * tokens,
                                    int32_t             n_tokens,
                                    void *              user_data) {
    callback_data * data = static_cast<callback_data *>(user_data);

    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
        int progress_percent = (step * 100) / total_steps;
        int progress_bars    = (step * 50) / total_steps;
        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
                step,
                total_steps,
                std::string(progress_bars, '=').c_str(),
                std::string(50 - progress_bars, ' ').c_str(),
                progress_percent);
    };

    if (data->diff_params->visual_mode) {
        // Visual mode: clear screen and move cursor to top-left
        LOG_INF("\033[2J\033[H");

        print_progress_bar(step, total_steps);

        LOG_INF("\n");

        std::string current_text = " ";

        for (int32_t i = data->n_input; i < n_tokens; i++) {
            std::string token_str;
            if (tokens[i] != llama_vocab_mask(data->vocab)) {
                char piece[256];
                // Leave room for the terminator so piece[n_chars] stays in bounds
                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece) - 1, 0, false);
                if (n_chars > 0) {
                    piece[n_chars] = '\0';
                    token_str      = piece;
                }
            } else {
                token_str = " ";
            }

            current_text += token_str;
        }

        LOG_INF("%s\n", current_text.c_str());
    } else {
        print_progress_bar(step, total_steps);
    }

    return true;
}

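// Reweight the logits with Gumbel noise, following LLaDA's formulation:
// exp(logit) / (-log(u))^temperature with u ~ U(0,1). The log of this
// expression is logits + temperature * G with standard Gumbel noise G, so an
// argmax over the result draws from softmax(logits / temperature). Noise is
// computed in double precision to limit rounding error, mirroring the
// float64 reference implementation.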
static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
    if (temperature == 0.0f) {
        return;
    }

    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    for (int32_t i = 0; i < n_vocab; i++) {
        double noise        = uniform(rng);
        // Prevent log(0)
        noise               = std::max(noise, 1e-20);
        double gumbel_noise = std::pow(-std::log(noise), temperature);
        logits[i]           = std::exp(logits[i]) / gumbel_noise;
    }
}

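// Spread mask_count unmaskings evenly across the given steps, assigning the
// remainder to the earliest steps: e.g. 10 masked tokens over 4 steps gives
// {3, 3, 2, 2}.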
static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
    std::vector<int32_t> num_transfer_tokens(steps);

    int32_t base      = mask_count / steps;
    int32_t remainder = mask_count % steps;

    for (int32_t i = 0; i < steps; i++) {
        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
    }

    return num_transfer_tokens;
}

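// Core loop: the output buffer starts as [prompt | mask tokens]. Every step
// decodes the full sequence with bidirectional (non-causal) attention,
// samples a candidate token for each still-masked position, and commits the
// highest-confidence candidates according to the transfer schedule, until no
// masked positions remain.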
static void diffusion_generate(llama_context *          ctx,
                               const llama_token *      input_tokens,
                               llama_token *            output_tokens,
                               int32_t                  n_input,
                               const diffusion_params & params,
                               int32_t &                n_generated) {
    n_generated = 0;
    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
        return;
    }

    const llama_model * model = llama_get_model(ctx);

    // Initialize with input and pad with mask tokens
    std::copy(input_tokens, input_tokens + n_input, output_tokens);
    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);

    std::mt19937 rng(params.seed);

    llama_set_causal_attn(ctx, false);

    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));

    std::vector<llama_token_data> candidates(n_vocab);
    std::vector<llama_token_data> conf_candidates;
    conf_candidates.reserve(params.max_length);
    std::vector<int32_t> mask_positions;
    mask_positions.reserve(params.max_length);

    // Setup sampler chain
    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
    if (params.top_k > 0) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
    }
    if (params.top_p < 1.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
    }
    if (params.temperature > 0.0f) {
        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
    }
    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));

    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);

    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
    batch.n_tokens    = params.max_length;

    // Pre-allocate buffers for CFG if needed
    int32_t                  logits_size = n_vocab * params.max_length;
    std::vector<float>       cond_logits_buffer;
    std::vector<llama_token> un_x_buffer;
    if (params.cfg_scale > 0.0f) {
        cond_logits_buffer.resize(logits_size);
        un_x_buffer.resize(params.max_length);
    }

    // For block-based processing
    std::vector<int32_t> num_transfer_tokens;
    int32_t              num_blocks      = 1;
    int32_t              steps_per_block = params.steps;

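    // BLOCK_BASED (LLaDA-style) decoding is semi-autoregressive: the masked
    // region is split into equal-sized blocks that are unmasked left to
    // right, with an equal share of the total steps spent on each block.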
    if (params.schedule == BLOCK_BASED) {
        GGML_ASSERT(params.max_length % params.block_length == 0);
        num_blocks = params.max_length / params.block_length;
        GGML_ASSERT(params.steps % num_blocks == 0);
        steps_per_block = params.steps / num_blocks;
    }

    int64_t total_sampling_time = 0;
    int64_t total_time          = 0;
    int64_t time_start          = ggml_time_us();

    for (int block_num = 0; block_num < num_blocks; block_num++) {
        int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
        int32_t block_end   = (params.schedule == BLOCK_BASED) ?
                                  std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
                                  params.max_length;

        // Count masked tokens in current block for block-based processing
        if (params.schedule == BLOCK_BASED) {
            int32_t block_mask_count = 0;
            for (int i = block_start; i < block_end; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    block_mask_count++;
                }
            }
            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
        }

        for (int32_t step = 0; step < steps_per_block; step++) {
            int32_t global_step = block_num * steps_per_block + step;

            if (params.step_callback) {
                if (!params.step_callback(
                        global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
                    break;
                }
            }

            // Setup batch
            for (int32_t i = 0; i < params.max_length; i++) {
                batch.token[i]     = output_tokens[i];
                batch.pos[i]       = i;
                batch.n_seq_id[i]  = 1;
                batch.seq_id[i][0] = 0;
                batch.logits[i]    = 1;
            }

            float * logits = nullptr;

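            // Classifier-free guidance: decode once with the prompt
            // (conditional) and once with the prompt replaced by mask tokens
            // (unconditional), then sharpen towards the conditional logits:
            //   logits = uncond + (cfg_scale + 1) * (cond - uncond)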
            if (params.cfg_scale > 0.0f) {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode conditional batch at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * cond_logits_ptr = llama_get_logits(ctx);
                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));

                // Unconditional generation (mask input)
                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
                for (int32_t i = 0; i < n_input; i++) {
                    un_x_buffer[i] = params.mask_token_id;
                }

                for (int32_t i = 0; i < params.max_length; i++) {
                    batch.token[i] = un_x_buffer[i];
                }
                ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode unconditional batch at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                float * uncond_logits = llama_get_logits(ctx);

                // Apply CFG
                for (int32_t i = 0; i < logits_size; i++) {
                    cond_logits_buffer[i] =
                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                }
                logits = cond_logits_buffer.data();
            } else {
                int ret = llama_decode(ctx, batch);
                if (ret != 0) {
                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
                    break;
                }
                logits = llama_get_logits(ctx);
            }

            if (!logits) {
                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
                break;
            }

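            // Dream-style checkpoints keep the next-token head of the base
            // autoregressive model, so the prediction for position pos is
            // produced at row pos - 1 of the logits; shift_logits selects
            // that row (position 0 falls back to its own row).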
            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
                if (params.shift_logits) {
                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
                }
                return logits + pos * n_vocab;
            };

            int64_t time_start_sampling = ggml_time_us();

            mask_positions.clear();
            for (int32_t i = 0; i < params.max_length; i++) {
                if (output_tokens[i] == params.mask_token_id) {
                    // For block-based, only consider current block
                    if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
                        mask_positions.push_back(i);
                    }
                }
            }

            if (mask_positions.empty()) {
                break;
            }

            if (params.add_gumbel_noise && params.temperature > 0.0f) {
                add_gumbel_noise(logits, n_vocab, params.temperature, rng);
            }

            if (params.algorithm == ORIGIN) {
                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
                float p_transfer = (float) transfer_count / mask_positions.size();

                for (int32_t pos : mask_positions) {
                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
                        const float * pos_logits = get_logits_for_pos(pos);
                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                            candidates[token_id].id    = token_id;
                            candidates[token_id].logit = pos_logits[token_id];
                            candidates[token_id].p     = 0.0f;
                        }

                        llama_token_data_array cur_p = {
                            candidates.data(),
                            (size_t) n_vocab,
                            -1,
                            false,
                        };

                        llama_sampler_apply(sampler, &cur_p);
                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
                    }
                }
            } else {
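                // Low-confidence remasking: sample a candidate for every
                // masked position, score each position, and commit only the
                // transfer_count most confident candidates this step; the
                // rest stay masked for later steps.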
                std::vector<std::pair<float, int32_t>> confidences;
                std::vector<llama_token>               sampled_tokens(mask_positions.size());

                for (size_t i = 0; i < mask_positions.size(); i++) {
                    int32_t       pos        = mask_positions[i];
                    const float * pos_logits = get_logits_for_pos(pos);

                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
                        candidates[token_id].logit = pos_logits[token_id];
                        candidates[token_id].p     = 0.0f;
                        candidates[token_id].id    = token_id;
                    }

                    llama_token_data_array cur_p = {
                        candidates.data(),
                        candidates.size(),
                        -1,
                        false,
                    };

                    llama_sampler_apply(sampler, &cur_p);
                    llama_token sampled_token = cur_p.data[cur_p.selected].id;

                    float conf = calculate_confidence(cur_p, params.algorithm, rng);

                    sampled_tokens[i] = sampled_token;
                    confidences.emplace_back(conf, i);
                }

                int32_t transfer_count = calculate_transfer_count(
                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);

                if (transfer_count > 0) {
                    if (params.alg_temp == 0.0f) {
                        std::partial_sort(confidences.begin(),
                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
                                          confidences.end(),
                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
                                              if (a.first != b.first) {
                                                  return a.first > b.first;
                                              }
                                              return a.second < b.second;
                                          });

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            int32_t mask_idx   = confidences[i].second;
                            int32_t pos        = mask_positions[mask_idx];
                            output_tokens[pos] = sampled_tokens[mask_idx];
                        }
                    } else {
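                        // alg_temp > 0: pick the positions to commit
                        // stochastically by drawing from
                        // softmax(confidence / alg_temp) with the dist
                        // sampler, instead of a deterministic top-k.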
                        conf_candidates.clear();
                        for (size_t i = 0; i < confidences.size(); i++) {
                            float conf_logit = confidences[i].first / params.alg_temp;
                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
                        }

                        llama_token_data_array conf_array = {
                            conf_candidates.data(),
                            conf_candidates.size(),
                            -1,
                            false,
                        };

                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
                            llama_sampler_apply(dist_sampler, &conf_array);
                            int32_t selected_idx = conf_array.selected;
                            int32_t mask_idx     = selected_idx;
                            int32_t pos          = mask_positions[mask_idx];
                            output_tokens[pos]   = sampled_tokens[mask_idx];

                            // Exclude this position from subsequent draws; the dist
                            // sampler recomputes probabilities from the logits, so
                            // the logit (not just p) must be invalidated
                            conf_candidates[selected_idx].logit = -std::numeric_limits<float>::infinity();
                            conf_candidates[selected_idx].p     = 0.0f;
                            conf_array.selected                 = -1;
                        }
                    }
                }
            }

            int64_t time_end_sampling = ggml_time_us();
            total_sampling_time += time_end_sampling - time_start_sampling;
        }
    }

    int64_t time_end = ggml_time_us();
    total_time += time_end - time_start;

    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
            total_time / 1000.0,
            total_time / 1000.0 / params.steps,
            total_sampling_time / 1000.0 / params.steps);

    llama_batch_free(batch);
    llama_sampler_free(sampler);
    llama_sampler_free(dist_sampler);

    n_generated = params.max_length;
}

static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
    if (!use_chat_template) {
        return prompt;
    }

    auto chat_templates = common_chat_templates_init(model, "");

    common_chat_templates_inputs inputs;
    common_chat_msg              system_msg;

    if (!system_prompt.empty()) {
        system_msg.role    = "system";
        system_msg.content = system_prompt;
        inputs.messages.push_back(system_msg);
    }

    common_chat_msg user_msg;
    user_msg.role    = "user";
    user_msg.content = prompt;

    inputs.messages.push_back(user_msg);
    inputs.add_generation_prompt = true;

    auto result = common_chat_templates_apply(chat_templates.get(), inputs);

    return result.prompt;
}

int main(int argc, char ** argv) {
    ggml_time_init();

    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
        return 1;
    }

    common_init();
    llama_backend_init();

    llama_model_params model_params = llama_model_default_params();
    model_params.n_gpu_layers       = params.n_gpu_layers;
    model_params.devices            = params.devices.data();
    model_params.use_mmap           = params.use_mmap;
    model_params.use_direct_io      = params.use_direct_io;
    model_params.use_mlock          = params.use_mlock;
    model_params.check_tensors      = params.check_tensors;

    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
    if (!model) {
        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
        return 1;
    }

    if (!llama_model_is_diffusion(model)) {
        LOG_ERR("error: unsupported model for diffusion\n");
        llama_model_free(model);
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_ctx                = params.n_ctx;
    ctx_params.n_batch              = params.n_batch;
    ctx_params.n_ubatch             = params.n_ubatch;
    ctx_params.flash_attn_type      = params.flash_attn_type;
    ctx_params.no_perf              = params.no_perf;
    ctx_params.type_k               = params.cache_type_k;
    ctx_params.type_v               = params.cache_type_v;

    llama_context * ctx = llama_init_from_model(model, ctx_params);
    if (!ctx) {
        LOG_ERR("error: failed to create context\n");
        llama_model_free(model);
        return 1;
    }

    llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);

    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);

    std::vector<llama_token> input_tokens = common_tokenize(vocab,
                                                            formatted_prompt,
                                                            /*add special tokens*/ true,
                                                            /*parse special*/ true);

    int n_input = input_tokens.size();

    if (n_input >= params.n_ctx) {
        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
        llama_free(ctx);
        llama_model_free(model);
        return 1;
    }

    llama_token mask_token_id = llama_vocab_mask(vocab);

    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);

    bool visual_mode = params.diffusion.visual_mode;

    int32_t                  n_generated = 0;
    std::vector<llama_token> output_tokens(params.n_ubatch);

    struct diffusion_params diff_params;

    char shift_logits_str[8];
    if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
        diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
    } else {
        diff_params.shift_logits = true;
    }

    // Use either eps or block length, but not both
    GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));

    if (params.diffusion.eps) {
        diff_params.schedule = TIMESTEP_BASED;
        diff_params.eps      = params.diffusion.eps;
    } else if (params.diffusion.block_length) {
        diff_params.schedule     = BLOCK_BASED;
        diff_params.block_length = params.diffusion.block_length;
    }

    diff_params.mask_token_id    = mask_token_id;
    diff_params.seed             = params.sampling.seed;
    diff_params.temperature      = params.sampling.temp;
    diff_params.steps            = params.diffusion.steps;
    diff_params.algorithm        = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
    diff_params.max_length       = params.n_ubatch;
    diff_params.top_p            = params.sampling.top_p;
    diff_params.top_k            = params.sampling.top_k;
    diff_params.visual_mode      = params.diffusion.visual_mode;
    diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;

    diff_params.step_callback           = diffusion_step_callback;
    callback_data cb_data               = { &diff_params, vocab, n_input };
    diff_params.step_callback_user_data = &cb_data;

    const char * alg_names[]   = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
    const char * alg_name =
        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
    const char * sched_name =
        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";

    LOG_INF("diffusion_params: - %-25s llama_token      = %d\n", "mask_token_id", mask_token_id);
    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "steps", diff_params.steps);
    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "max_length", diff_params.max_length);
    LOG_INF("diffusion_params: - %-25s enum             = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
    LOG_INF("diffusion_params: - %-25s enum             = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
    LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "temperature", diff_params.temperature);
    if (diff_params.schedule == TIMESTEP_BASED) {
        LOG_INF("diffusion_params: - %-25s f32              = %.6f\n", "eps", diff_params.eps);
        LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "alg_temp", diff_params.alg_temp);
    }
    if (diff_params.schedule == BLOCK_BASED) {
        LOG_INF("diffusion_params: - %-25s u32              = %d\n", "block_length", diff_params.block_length);
        LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "cfg_scale", diff_params.cfg_scale);
    }

    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);

    if (n_generated > 0) {
        if (visual_mode) {
            // Clear screen and move cursor to top-left
            LOG_INF("\033[2J\033[H");
        }

        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
        std::string output_data = common_detokenize(vocab, output_tokens, false);
        LOG_INF("\n%s\n", output_data.c_str());
    } else {
        LOG_ERR("error: diffusion generation failed\n");
    }

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}