Diffstat (limited to 'llama.cpp/examples/eval-callback/eval-callback.cpp')
| -rw-r--r-- | llama.cpp/examples/eval-callback/eval-callback.cpp | 80 |
1 file changed, 80 insertions, 0 deletions
diff --git a/llama.cpp/examples/eval-callback/eval-callback.cpp b/llama.cpp/examples/eval-callback/eval-callback.cpp
new file mode 100644
index 0000000..bd58734
--- /dev/null
+++ b/llama.cpp/examples/eval-callback/eval-callback.cpp
@@ -0,0 +1,80 @@
+#include "arg.h"
+#include "common.h"
+#include "debug.h"
+#include "log.h"
+#include "llama.h"
+#include "llama-cpp.h"
+#include <string>
+#include <vector>
+
+static bool run(llama_context * ctx, const common_params & params) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
+
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
+
+    if (tokens.empty()) {
+        LOG_ERR("%s : there are no input tokens to process (try to provide a prompt with '-p')\n", __func__);
+        return false;
+    }
+
+    if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
+        LOG_ERR("%s : failed to eval\n", __func__);
+        return false;
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    base_callback_data cb_data;
+
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+        return 1;
+    }
+
+    common_init();
+
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    // pass the callback to the backend scheduler
+    // it will be executed for each node during the graph computation
+    params.cb_eval           = common_debug_cb_eval<false>;
+    params.cb_eval_user_data = &cb_data;
+    params.warmup            = false;
+
+    // init
+    auto llama_init = common_init_from_params(params);
+
+    auto * model = llama_init->model();
+    auto * ctx   = llama_init->context();
+
+    if (model == nullptr || ctx == nullptr) {
+        LOG_ERR("%s : failed to init\n", __func__);
+        return 1;
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+        LOG_INF("\n");
+    }
+
+    bool OK = run(ctx, params);
+    if (!OK) {
+        return 1;
+    }
+
+    LOG("\n");
+    llama_perf_context_print(ctx);
+
+    llama_backend_free();
+
+    return 0;
+}
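
The core of this example is the cb_eval hook: common_debug_cb_eval<false> (declared in debug.h in this tree) is installed as the backend scheduler's eval callback, so it runs for every node of the compute graph. As a rough illustration of the mechanism, below is a minimal sketch of a custom callback with the same shape, assuming it conforms to ggml's ggml_backend_sched_eval_callback type; the names my_cb_data and my_eval_cb are hypothetical and not part of this commit.

    #include "ggml.h"
    #include <cinttypes>
    #include <cstdio>

    // hypothetical per-run state, standing in for base_callback_data
    struct my_cb_data {
        int n_nodes = 0;
    };

    // shape of ggml_backend_sched_eval_callback: the scheduler calls it with
    // ask == true to query whether this tensor's data is wanted (so it can be
    // copied back from the backend), then again with ask == false once the
    // node has been computed
    static bool my_eval_cb(struct ggml_tensor * t, bool ask, void * user_data) {
        auto * data = (my_cb_data *) user_data;

        if (ask) {
            return true; // request data for every node
        }

        data->n_nodes++;
        printf("%4d: %-24s %-8s [%5" PRId64 ", %5" PRId64 "]\n",
               data->n_nodes, t->name, ggml_op_desc(t), t->ne[0], t->ne[1]);

        return true; // returning false aborts the graph computation
    }

Wiring such a callback in would mirror what the example does with common_debug_cb_eval: set params.cb_eval = my_eval_cb and params.cb_eval_user_data to a pointer to the state struct before calling common_init_from_params(params), and keep params.warmup = false so the callback does not also fire during the warmup run.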
