diff options
Diffstat (limited to 'llama.cpp/examples/diffusion')
| -rw-r--r-- | llama.cpp/examples/diffusion/CMakeLists.txt | 5 | ||||
| -rw-r--r-- | llama.cpp/examples/diffusion/README.md | 59 | ||||
| -rw-r--r-- | llama.cpp/examples/diffusion/diffusion-cli.cpp | 694 |
3 files changed, 758 insertions, 0 deletions
diff --git a/llama.cpp/examples/diffusion/CMakeLists.txt b/llama.cpp/examples/diffusion/CMakeLists.txt new file mode 100644 index 0000000..396549c --- /dev/null +++ b/llama.cpp/examples/diffusion/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-diffusion-cli) +add_executable(${TARGET} diffusion-cli.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/diffusion/README.md b/llama.cpp/examples/diffusion/README.md new file mode 100644 index 0000000..f71d241 --- /dev/null +++ b/llama.cpp/examples/diffusion/README.md @@ -0,0 +1,59 @@ +# Diffusion Text Generation + +This directory contains implementations for Diffusion LLMs (DLLMs) + +More Info: +- https://github.com/ggml-org/llama.cpp/pull/14644 +- https://github.com/ggml-org/llama.cpp/pull/14771 + +## Parameters +The diffusion CLI supports various parameters to control the generation process: + +### Core Diffusion Parameters +- `--diffusion-steps`: Number of diffusion steps (default: 256) +- `--diffusion-algorithm`: Algorithm for token selection + - `0`: ORIGIN - Token will be generated in a purely random order from https://arxiv.org/abs/2107.03006. + - `1`: ENTROPY_BASED - Entropy-based selection + - `2`: MARGIN_BASED - Margin-based selection + - `3`: RANDOM - Random selection + - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - More documentation here https://github.com/DreamLM/Dream +- `--diffusion-visual`: Enable live visualization during generation + +### Scheduling Parameters +Choose one of the following scheduling methods: + +**Timestep-based scheduling:** +- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001) + +**Block-based scheduling:** +- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32) + +### Sampling Parameters +- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random) +- `--top-k`: Top-k filtering for sampling +- `--top-p`: Top-p (nucleus) filtering for sampling +- `--seed`: Random seed for reproducibility + +### Model Parameters +- `-m`: Path to the GGUF model file +- `-p`: Input prompt text +- `-ub`: Maximum sequence length (ubatch size) +- `-c`: Context size +- `-b`: Batch size + +### Examples +#### Dream architechture: +``` +llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual +``` + +#### LLaDA architechture: +``` +llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual +``` + +#### RND1 architecture: +``` +llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 +``` diff --git a/llama.cpp/examples/diffusion/diffusion-cli.cpp b/llama.cpp/examples/diffusion/diffusion-cli.cpp new file mode 100644 index 0000000..d50f754 --- /dev/null +++ b/llama.cpp/examples/diffusion/diffusion-cli.cpp @@ -0,0 +1,694 @@ +#include "arg.h" +#include "chat.h" +#include "common.h" +#include "llama.h" +#include "log.h" + +#include <limits.h> + +#include <algorithm> +#include <cmath> +#include <cstring> +#include <limits> +#include <random> +#include <string> +#include <vector> + +enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 }; + +// Unified transfer scheduling methods +enum transfer_schedule { + TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining + BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens +}; + +typedef bool (*diffusion_step_callback_t)(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data); + +struct diffusion_params { + int32_t steps = 0; + float temperature = 0; + llama_token mask_token_id = LLAMA_TOKEN_NULL; + diffusion_step_callback_t step_callback = nullptr; + void * step_callback_user_data = nullptr; + int32_t seed = 0; + bool visual_mode = false; + bool shift_logits = false; // Shift logits by -1 after decode + + float top_p = 0.; + int32_t top_k = 0.; + + diffusion_algorithm algorithm = CONFIDENCE_BASED; + transfer_schedule schedule = TIMESTEP_BASED; + + float cfg_scale = 0.; // Config scale for classifier-free guidance + float eps = 0.; // Timestep scheduling + int32_t block_length = 0; // Block size (for block scheduling) + float alg_temp = 0; // algorithm temperature (0.0 = deterministic) + bool add_gumbel_noise = false; // Add gumbel noise to the logits if temp > 0.0 + + int32_t max_length = 0; // Maximum sequence length +}; + +struct callback_data { + diffusion_params * diff_params; + const llama_vocab * vocab; + int32_t n_input; +}; + +static float calculate_confidence(const llama_token_data_array & cur_p, + diffusion_algorithm algorithm, + std::mt19937 & rng) { + switch (algorithm) { + case CONFIDENCE_BASED: + return cur_p.data[cur_p.selected].p; // Selected token probability + + case ENTROPY_BASED: + { + float entropy = 0.0f; + const float epsilon = 1e-10f; + for (size_t i = 0; i < cur_p.size; i++) { + float prob = cur_p.data[i].p; + entropy += prob * logf(prob + epsilon); + } + return -entropy; // Higher entropy = lower confidence + } + + case MARGIN_BASED: + return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p; + + case RANDOM: + { + std::uniform_real_distribution<float> uniform(0.0f, 1.0f); + return uniform(rng); // Random confidence + } + + case ORIGIN: + return cur_p.data[cur_p.selected].p; + + default: + return 0.0f; + } +} + +// Unified transfer count calculation function +static int32_t calculate_transfer_count(int32_t step, + int32_t total_steps, + int32_t remaining_masked, + transfer_schedule schedule, + float eps, + const std::vector<int32_t> & num_transfer_tokens = {}) { + switch (schedule) { + case TIMESTEP_BASED: + { + float t = 1.0f - (float) step / total_steps * (1.0f - eps); + float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps); + float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f; + return (int32_t) (remaining_masked * p_transfer); + } + + case BLOCK_BASED: + if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) { + return num_transfer_tokens[step]; + } + return remaining_masked / (total_steps - step); // Fallback + + default: + return remaining_masked / (total_steps - step); + } +} + +static bool diffusion_step_callback(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data) { + (void) user_data; + + callback_data * data = static_cast<callback_data *>(user_data); + + auto print_progress_bar = [](int32_t step, int32_t total_steps) { + int progress_percent = (step * 100) / total_steps; + int progress_bars = (step * 50) / total_steps; + LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%", + step, + total_steps, + std::string(progress_bars, '=').c_str(), + std::string(50 - progress_bars, ' ').c_str(), + progress_percent); + }; + + if (data->diff_params->visual_mode) { + // Visual mode: clear + LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left + + print_progress_bar(step, total_steps); + + LOG_INF("\n"); + + std::string current_text = " "; + + for (int32_t i = data->n_input; i < n_tokens; i++) { + std::string token_str; + if (tokens[i] != llama_vocab_mask(data->vocab)) { + char piece[256]; + int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false); + if (n_chars > 0) { + piece[n_chars] = '\0'; + token_str = piece; + } + } else { + token_str = " "; + } + + current_text += token_str; + } + + LOG_INF("%s\n", current_text.c_str()); + } else { + print_progress_bar(step, total_steps); + } + + return true; +} + +static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) { + if (temperature == 0.0f) { + return; + } + + std::uniform_real_distribution<double> uniform(0.0, 1.0); + for (int32_t i = 0; i < n_vocab; i++) { + double noise = uniform(rng); + // Prevent log(0) + noise = std::max(noise, 1e-20); + double gumbel_noise = std::pow(-std::log(noise), temperature); + logits[i] = std::exp(logits[i]) / gumbel_noise; + } +} + +static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) { + std::vector<int32_t> num_transfer_tokens(steps); + + int32_t base = mask_count / steps; + int32_t remainder = mask_count % steps; + + for (int32_t i = 0; i < steps; i++) { + num_transfer_tokens[i] = base + (i < remainder ? 1 : 0); + } + + return num_transfer_tokens; +} + +static void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + const diffusion_params & params, + int32_t & n_generated) { + n_generated = 0; + if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) { + return; + } + + const llama_model * model = llama_get_model(ctx); + + // Initialize with input and pad with mask tokens + std::copy(input_tokens, input_tokens + n_input, output_tokens); + std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id); + + std::mt19937 rng(params.seed); + + llama_set_causal_attn(ctx, false); + + int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + + std::vector<llama_token_data> candidates(n_vocab); + std::vector<llama_token_data> conf_candidates; + conf_candidates.reserve(params.max_length); + std::vector<int32_t> mask_positions; + mask_positions.reserve(params.max_length); + + // Setup sampler chain + struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); + if (params.top_k > 0) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); + } + if (params.top_p < 1.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); + } + if (params.temperature > 0.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); + } + llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); + + struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); + + llama_batch batch = llama_batch_init(params.max_length, 0, 1); + batch.n_tokens = params.max_length; + + // Pre-allocate buffers for CFG if needed + int32_t logits_size = n_vocab * params.max_length; + std::vector<float> cond_logits_buffer; + std::vector<llama_token> un_x_buffer; + if (params.cfg_scale > 0.0f) { + cond_logits_buffer.resize(logits_size); + un_x_buffer.resize(params.max_length); + } + + // For block-based processing + std::vector<int32_t> num_transfer_tokens; + int32_t num_blocks = 1; + int32_t steps_per_block = params.steps; + + if (params.schedule == BLOCK_BASED) { + GGML_ASSERT(params.max_length % params.block_length == 0); + num_blocks = params.max_length / params.block_length; + GGML_ASSERT(params.steps % num_blocks == 0); + steps_per_block = params.steps / num_blocks; + } + + std::vector<float> confidence(params.max_length); + + int64_t total_sampling_time = 0; + int64_t total_time = 0; + int64_t time_start = ggml_time_us(); + + for (int block_num = 0; block_num < num_blocks; block_num++) { + int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0; + int32_t block_end = (params.schedule == BLOCK_BASED) ? + std::min(n_input + (block_num + 1) * params.block_length, params.max_length) : + params.max_length; + + // Count masked tokens in current block for block-based processing + if (params.schedule == BLOCK_BASED) { + int32_t block_mask_count = 0; + for (int i = block_start; i < block_end; i++) { + if (output_tokens[i] == params.mask_token_id) { + block_mask_count++; + } + } + num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block); + } + + for (int32_t step = 0; step < steps_per_block; step++) { + int32_t global_step = block_num * steps_per_block + step; + + if (params.step_callback) { + if (!params.step_callback( + global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) { + break; + } + } + + // Setup batch + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = output_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; + } + + float * logits = nullptr; + + if (params.cfg_scale > 0.0f) { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate conditional"); + break; + } + float * cond_logits_ptr = llama_get_logits(ctx); + std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float)); + + // Unconditional generation (mask input) + std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin()); + for (int32_t i = 0; i < n_input; i++) { + un_x_buffer[i] = params.mask_token_id; + } + + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = un_x_buffer[i]; + } + ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate unconditional"); + break; + } + float * uncond_logits = llama_get_logits(ctx); + + // Apply CFG + for (int32_t i = 0; i < logits_size; i++) { + cond_logits_buffer[i] = + uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]); + } + logits = cond_logits_buffer.data(); + } else { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret); + break; + } + logits = llama_get_logits(ctx); + } + + if (!logits) { + LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step); + break; + } + + auto get_logits_for_pos = [&](int32_t pos) -> const float * { + if (params.shift_logits) { + return pos == 0 ? logits : logits + (pos - 1) * n_vocab; + } + return logits + (pos) *n_vocab; + }; + + int64_t time_start_sampling = ggml_time_us(); + + mask_positions.clear(); + for (int32_t i = 0; i < params.max_length; i++) { + if (output_tokens[i] == params.mask_token_id) { + // For block-based, only consider current block + if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) { + mask_positions.push_back(i); + } + } + } + + if (mask_positions.empty()) { + break; + } + + if (params.add_gumbel_noise && params.temperature > 0.0f) { + add_gumbel_noise(logits, n_vocab, params.temperature, rng); + } + + if (params.algorithm == ORIGIN) { + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + float p_transfer = (float) transfer_count / mask_positions.size(); + + for (int32_t pos : mask_positions) { + if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) { + const float * pos_logits = get_logits_for_pos(pos); + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].id = token_id; + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + } + + llama_token_data_array cur_p = { + candidates.data(), + (size_t) n_vocab, + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + output_tokens[pos] = cur_p.data[cur_p.selected].id; + } + } + } else { + std::vector<std::pair<float, int32_t>> confidences; + std::vector<llama_token> sampled_tokens(mask_positions.size()); + + for (size_t i = 0; i < mask_positions.size(); i++) { + int32_t pos = mask_positions[i]; + const float * pos_logits = get_logits_for_pos(pos); + + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + candidates[token_id].id = token_id; + } + + llama_token_data_array cur_p = { + candidates.data(), + candidates.size(), + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + llama_token sampled_token = cur_p.data[cur_p.selected].id; + + float conf = calculate_confidence(cur_p, params.algorithm, rng); + + sampled_tokens[i] = sampled_token; + confidences.emplace_back(conf, i); + } + + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + + if (transfer_count > 0) { + if (params.alg_temp == 0.0f) { + std::partial_sort(confidences.begin(), + confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()), + confidences.end(), + [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + int32_t mask_idx = confidences[i].second; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + } + } else { + conf_candidates.clear(); + for (size_t i = 0; i < confidences.size(); i++) { + float conf_logit = confidences[i].first / params.alg_temp; + conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f }); + } + + llama_token_data_array conf_array = { + conf_candidates.data(), + conf_candidates.size(), + -1, + false, + }; + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + llama_sampler_apply(dist_sampler, &conf_array); + int32_t selected_idx = conf_array.selected; + int32_t mask_idx = selected_idx; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + + conf_candidates[selected_idx].p = 0.0f; + conf_array.selected = -1; + } + } + } + } + + int64_t time_end_sampling = ggml_time_us(); + total_sampling_time += time_end_sampling - time_start_sampling; + } + } + + int64_t time_end = ggml_time_us(); + total_time += time_end - time_start; + + LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", + total_time / 1000.0, + total_time / 1000.0 / params.steps, + total_sampling_time / 1000.0 / params.steps); + + llama_batch_free(batch); + llama_sampler_free(sampler); + llama_sampler_free(dist_sampler); + + n_generated = params.max_length; +} + +static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) { + if (!use_chat_template) { + return prompt; + } + + auto chat_templates = common_chat_templates_init(model, ""); + common_chat_templates_inputs inputs; + common_chat_msg system_msg; + + if (!system_prompt.empty()) { + system_msg.role = "system"; + system_msg.content = system_prompt; + inputs.messages.push_back(system_msg); + } + + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = prompt; + + inputs.messages.push_back(user_msg); + inputs.add_generation_prompt = true; + + auto result = common_chat_templates_apply(chat_templates.get(), inputs); + + return result.prompt; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) { + return 1; + } + + common_init(); + llama_backend_init(); + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = params.n_gpu_layers; + model_params.devices = params.devices.data(); + model_params.use_mmap = params.use_mmap; + model_params.use_direct_io = params.use_direct_io; + model_params.use_mlock = params.use_mlock; + model_params.check_tensors = params.check_tensors; + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + if (!model) { + LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str()); + return 1; + } + + if (!llama_model_is_diffusion(model)) { + LOG_ERR("error: unsupported model for diffusion"); + llama_model_free(model); + return 1; + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = params.n_ctx; + ctx_params.n_batch = params.n_batch; + ctx_params.n_ubatch = params.n_ubatch; + ctx_params.flash_attn_type = params.flash_attn_type; + ctx_params.no_perf = params.no_perf; + ctx_params.type_k = params.cache_type_k; + ctx_params.type_v = params.cache_type_v; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + LOG_ERR("error: failed to create context\n"); + llama_model_free(model); + return 1; + } + + llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads); + + const llama_vocab * vocab = llama_model_get_vocab(model); + + std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model); + + std::vector<llama_token> input_tokens = common_tokenize(vocab, + formatted_prompt, + /*add special tokens*/ true, + /*parse special*/ true); + + int n_input = input_tokens.size(); + + if (n_input >= params.n_ctx) { + LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx); + llama_free(ctx); + llama_model_free(model); + return 1; + } + + llama_token mask_token_id = llama_vocab_mask(vocab); + + GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL); + + bool visual_mode = params.diffusion.visual_mode; + + int32_t n_generated = 0; + std::vector<llama_token> output_tokens(params.n_ubatch); + + struct diffusion_params diff_params; + + char shift_logits_str[8]; + if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) { + diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0); + } else { + diff_params.shift_logits = true; + } + + //Use either eps or block length, but not both + GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0)); + + if (params.diffusion.eps) { + diff_params.schedule = TIMESTEP_BASED; + diff_params.eps = params.diffusion.eps; + } else if (params.diffusion.block_length) { + diff_params.schedule = BLOCK_BASED; + diff_params.block_length = params.diffusion.block_length; + } + + diff_params.mask_token_id = mask_token_id; + diff_params.seed = params.sampling.seed; + diff_params.temperature = params.sampling.temp; + diff_params.steps = params.diffusion.steps; + diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm); + diff_params.max_length = params.n_ubatch; + diff_params.top_p = params.sampling.top_p; + diff_params.top_k = params.sampling.top_k; + diff_params.visual_mode = params.diffusion.visual_mode; + diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise; + + diff_params.step_callback = diffusion_step_callback; + callback_data cb_data = { &diff_params, vocab, n_input }; + diff_params.step_callback_user_data = &cb_data; + + const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" }; + const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" }; + const char * alg_name = + (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN"; + const char * sched_name = + (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN"; + + LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature); + if (diff_params.schedule == TIMESTEP_BASED) { + LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp); + } + if (diff_params.schedule == BLOCK_BASED) { + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale); + } + + diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated); + + if (n_generated > 0) { + if (visual_mode) { + //clear screen and move cursor to top-left + LOG_INF("\033[2J\033[H"); + } + + output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input); + std::string output_data = common_detokenize(vocab, output_tokens, false); + LOG_INF("\n%s\n", output_data.c_str()); + } else { + LOG_INF("Error: diffusion generation failed\n"); + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + + return 0; +} |
