Diffstat (limited to 'llama.cpp/examples')
212 files changed, 30084 insertions, 0 deletions
diff --git a/llama.cpp/examples/CMakeLists.txt b/llama.cpp/examples/CMakeLists.txt new file mode 100644 index 0000000..a29dc70 --- /dev/null +++ b/llama.cpp/examples/CMakeLists.txt @@ -0,0 +1,45 @@ +# dependencies + +find_package(Threads REQUIRED) + +# third-party + +# ... + +# flags + +llama_add_compile_flags() + +# examples + +if (EMSCRIPTEN) +else() + add_subdirectory(batched) + add_subdirectory(debug) + add_subdirectory(embedding) + add_subdirectory(eval-callback) + + add_subdirectory(gguf-hash) + add_subdirectory(gguf) + add_subdirectory(idle) + add_subdirectory(lookahead) + add_subdirectory(lookup) + add_subdirectory(parallel) + add_subdirectory(passkey) + add_subdirectory(retrieval) + add_subdirectory(save-load-state) + add_subdirectory(simple) + add_subdirectory(simple-chat) + add_subdirectory(speculative) + add_subdirectory(speculative-simple) + add_subdirectory(gen-docs) + add_subdirectory(training) + add_subdirectory(diffusion) + if (NOT GGML_BACKEND_DL) + add_subdirectory(convert-llama2c-to-ggml) + # these examples use the backends directly and cannot be built with dynamic loading + if (GGML_SYCL) + add_subdirectory(sycl) + endif() + endif() +endif() diff --git a/llama.cpp/examples/batched.swift/.gitignore b/llama.cpp/examples/batched.swift/.gitignore new file mode 100644 index 0000000..e1e863b --- /dev/null +++ b/llama.cpp/examples/batched.swift/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc +batched_swift diff --git a/llama.cpp/examples/batched.swift/Makefile b/llama.cpp/examples/batched.swift/Makefile new file mode 100755 index 0000000..1f9156e --- /dev/null +++ b/llama.cpp/examples/batched.swift/Makefile @@ -0,0 +1,6 @@ +.PHONY: build + +build: + xcodebuild -scheme llama-batched-swift -destination "generic/platform=macOS" -derivedDataPath build + rm -f ./llama-batched-swift + ln -s ./build/Build/Products/Debug/llama-batched-swift ./llama-batched-swift diff --git a/llama.cpp/examples/batched.swift/Package.swift b/llama.cpp/examples/batched.swift/Package.swift new file mode 100644 index 0000000..7e8afd0 --- /dev/null +++ b/llama.cpp/examples/batched.swift/Package.swift @@ -0,0 +1,22 @@ +// swift-tools-version: 5.5 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "llama-batched-swift", + platforms: [.macOS(.v12)], + dependencies: [ + .package(name: "llama", path: "../../"), + ], + targets: [ + // Targets are the basic building blocks of a package, defining a module or a test suite. + // Targets can depend on other targets in this package and products from dependencies. + .executableTarget( + name: "llama-batched-swift", + dependencies: ["llama"], + path: "Sources", + linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")] + ), + ] +) diff --git a/llama.cpp/examples/batched.swift/README.md b/llama.cpp/examples/batched.swift/README.md new file mode 100644 index 0000000..f089015 --- /dev/null +++ b/llama.cpp/examples/batched.swift/README.md @@ -0,0 +1,5 @@ +This is a swift clone of `examples/batched`. 
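+Build it with `make build` (see the Makefile above), then run: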
+ +```bash +$ ./llama-batched-swift MODEL_PATH [PROMPT] [PARALLEL] +``` diff --git a/llama.cpp/examples/batched.swift/Sources/main.swift b/llama.cpp/examples/batched.swift/Sources/main.swift new file mode 100644 index 0000000..fd90bbe --- /dev/null +++ b/llama.cpp/examples/batched.swift/Sources/main.swift @@ -0,0 +1,256 @@ +import Foundation +import llama + +let arguments = CommandLine.arguments + +// Check that we have at least one argument (the model path) +guard arguments.count > 1 else { + print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]") + exit(1) +} + +let modelPath: String = arguments[1] +let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is" +let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1 + +// total length of the sequences including the prompt +let n_len: Int = 32 + +// init LLM +llama_backend_init() +defer { + llama_backend_free() +} + +let model_params = llama_model_default_params() +guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else { + print("Failed to load model") + exit(1) +} +defer { + llama_model_free(model) +} + +guard let vocab = llama_model_get_vocab(model) else { + print("Failed to get vocab") + exit(1) +} + +var tokens = tokenize(text: prompt, add_bos: true) + +let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) + +var context_params = llama_context_default_params() +context_params.n_ctx = n_kv_req +context_params.n_batch = UInt32(max(n_len, n_parallel)) +context_params.n_threads = 8 +context_params.n_threads_batch = 8 + +let context = llama_init_from_model(model, context_params) +guard context != nil else { + print("Failed to initialize context") + exit(1) +} +defer { + llama_free(context) +} + +var sparams = llama_sampler_chain_default_params() + +let smpl = llama_sampler_chain_init(sparams) +guard smpl != nil else { + print("Failed to initialize sampling") + exit(1) +} +defer { + llama_sampler_free(smpl) +} + +llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40)); +llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); +llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4)); +llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234)); + +let n_ctx = llama_n_ctx(context) + +print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") + +if n_kv_req > n_ctx { + print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req) + exit(1) +} + +var buffer: [CChar] = [] +for id: llama_token in tokens { + print(token_to_piece(token: id, buffer: &buffer) ?? "", terminator: "") +} + +print("\n") + +var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1) +defer { + llama_batch_free(batch) +} + +// evaluate the initial prompt +batch.n_tokens = Int32(tokens.count) + +for (i, token) in tokens.enumerated() { + batch.token[i] = token + batch.pos[i] = Int32(i) + batch.n_seq_id[i] = 1 + // batch.seq_id[i][0] = 0 + // TODO: is this the proper way to do this? 
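+        // batch.seq_id[i] is imported from the C API as an optional
+        // UnsafeMutablePointer<llama_seq_id>, so binding it with `if let`
+        // before writing through it is a safe way to set the sequence id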
+ if let seq_id = batch.seq_id[i] { + seq_id[0] = 0 + } + batch.logits[i] = 0 +} + +// llama_decode will output logits only for the last token of the prompt +batch.logits[Int(batch.n_tokens) - 1] = 1 + +if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) +} + +for i in 1 ..< n_parallel { + llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens) +} + +if n_parallel > 1 { + print("generating \(n_parallel) sequences ...\n") +} + +var streams: [String] = .init(repeating: "", count: n_parallel) +var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel) +var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel) + +var n_cur = batch.n_tokens +var n_decode = 0 + +let t_main_start = ggml_time_us() + +while n_cur <= n_len { + // prepare the next batch + batch.n_tokens = 0 + + // sample the next token for each parallel sequence / stream + for i in 0 ..< n_parallel { + if i_batch[i] < 0 { + // the stream has already finished + continue + } + + let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) + + // is it an end of stream? -> mark the stream as finished + if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { + i_batch[i] = -1 + // print("") + if n_parallel > 1 { + print("stream \(i) finished at n_cur = \(n_cur)") + } + + continue + } + + let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? "" + + // if there is only one stream, we print immediately to stdout + if n_parallel == 1 { + print(nextStringPiece, terminator: "") + } + streams[i] += nextStringPiece + + // push this new token for next evaluation + batch.token[Int(batch.n_tokens)] = new_token_id + batch.pos[Int(batch.n_tokens)] = n_cur + batch.n_seq_id[Int(batch.n_tokens)] = 1 + if let seq_id = batch.seq_id[Int(batch.n_tokens)] { + seq_id[0] = Int32(i) + } + batch.logits[Int(batch.n_tokens)] = 1 + + i_batch[i] = batch.n_tokens + + batch.n_tokens += 1 + + n_decode += 1 + } + + // all streams are finished + if batch.n_tokens == 0 { + break + } + + n_cur += 1 + + // evaluate the current batch with the transformer model + if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) + } +} + +if n_parallel > 1 { + print("\n") + for (i, stream) in streams.enumerated() { + print("sequence \(i):\n\n\(prompt)\(stream)\n") + } +} + +let t_main_end = ggml_time_us() + +print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n") + +llama_perf_sampler_print(smpl) +llama_perf_context_print(context) + +private func tokenize(text: String, add_bos: Bool) -> [llama_token] { + let utf8Count = text.utf8.count + let n_tokens = utf8Count + (add_bos ? 1 : 0) + let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens) + let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false) + var swiftTokens: [llama_token] = [] + for i in 0 ..< tokenCount { + swiftTokens.append(tokens[Int(i)]) + } + tokens.deallocate() + return swiftTokens +} + +private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? 
{ + var result = [CChar](repeating: 0, count: 8) + let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false) + if nTokens < 0 { + let actualTokensCount = -Int(nTokens) + result = .init(repeating: 0, count: actualTokensCount) + let check = llama_token_to_piece( + vocab, + token, + &result, + Int32(result.count), + 0, + false + ) + assert(check == actualTokensCount) + } else { + result.removeLast(result.count - Int(nTokens)) + } + if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) { + return utfString + } else { + buffer.append(contentsOf: result) + let data = Data(buffer.map { UInt8(bitPattern: $0) }) + if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer + buffer = [] + } + guard let bufferString = String(data: data, encoding: .utf8) else { + return nil + } + buffer = [] + return bufferString + } +} diff --git a/llama.cpp/examples/batched/CMakeLists.txt b/llama.cpp/examples/batched/CMakeLists.txt new file mode 100644 index 0000000..0d439f4 --- /dev/null +++ b/llama.cpp/examples/batched/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-batched) +add_executable(${TARGET} batched.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/batched/README.md b/llama.cpp/examples/batched/README.md new file mode 100644 index 0000000..8cde35d --- /dev/null +++ b/llama.cpp/examples/batched/README.md @@ -0,0 +1,44 @@ +# llama.cpp/example/batched + +The example demonstrates batched generation from a given prompt + +```bash +./llama-batched -m ./models/llama-7b-v2/ggml-model-f16.gguf -p "Hello my name is" -np 4 --kv-unified + +... + +main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113 + + Hello my name is + +main: generating 4 sequences ... + +main: stream 0 finished +main: stream 1 finished +main: stream 2 finished +main: stream 3 finished + +sequence 0: + +Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b + +sequence 1: + +Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between + +sequence 2: + +Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am + +sequence 3: + +Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. 
I am very playful and + +main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s + +llama_print_timings: load time = 587.00 ms +llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second) +llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second) +llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second) +llama_print_timings: total time = 4156.04 ms +``` diff --git a/llama.cpp/examples/batched/batched.cpp b/llama.cpp/examples/batched/batched.cpp new file mode 100644 index 0000000..a8c19a6 --- /dev/null +++ b/llama.cpp/examples/batched/batched.cpp @@ -0,0 +1,261 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "sampling.h" + +#include <algorithm> +#include <cstdio> +#include <string> +#include <vector> + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32 -np 4\n", argv[0]); + LOG("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + params.prompt = "Hello my name is"; + params.n_predict = 32; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BATCHED, print_usage)) { + return 1; + } + + common_init(); + + // number of parallel batches + int n_parallel = params.n_parallel; + + // total length of the sequences including the prompt + int n_predict = params.n_predict; + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = common_model_params_to_llama(params); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + + if (model == NULL) { + LOG_ERR("%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // tokenize the prompt + + std::vector<llama_token> tokens_list; + tokens_list = common_tokenize(vocab, params.prompt, true); + + const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel; + + // initialize the context + + llama_context_params ctx_params = common_context_params_to_llama(params); + + ctx_params.n_ctx = n_kv_req; + ctx_params.n_batch = std::max(n_predict, n_parallel); + + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + + std::vector<llama_sampler_seq_config> sampler_configs; + + for (int32_t i = 0; i < n_parallel; ++i) { + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sampling.top_p, params.sampling.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); + + sampler_configs.push_back({ i, smpl }); + } + + if (params.sampling.backend_sampling) { + ctx_params.samplers = sampler_configs.data(); + ctx_params.n_samplers = sampler_configs.size(); + } + + llama_context * ctx = llama_init_from_model(model, ctx_params); + + if (ctx == NULL) { + LOG_ERR("%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + const int n_ctx = llama_n_ctx(ctx); + + LOG_INF("\n%s: n_predict = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req); + + // make sure the KV cache is big enough to 
hold all the prompt and generated tokens + if (n_kv_req > n_ctx) { + LOG_ERR("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req); + LOG_ERR("%s: either reduce n_parallel or increase n_ctx\n", __func__); + return 1; + } + + // print the prompt token-by-token + + LOG("\n"); + + for (auto id : tokens_list) { + LOG("%s", common_token_to_piece(ctx, id).c_str()); + } + + // create a llama_batch + // we use this object to submit token data for decoding + llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); + + std::vector<llama_seq_id> seq_ids(n_parallel, 0); + for (int32_t i = 0; i < n_parallel; ++i) { + seq_ids[i] = i; + } + + // evaluate the initial prompt + for (size_t i = 0; i < tokens_list.size(); ++i) { + common_batch_add(batch, tokens_list[i], i, seq_ids, false); + } + GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); + + if (llama_model_has_encoder(model)) { + if (llama_encode(ctx, batch)) { + LOG_ERR("%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = llama_vocab_bos(vocab); + } + + common_batch_clear(batch); + common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + } + + // llama_decode will output logits only for the last token of the prompt + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + //// assign the system KV cache to all parallel sequences + //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them + //for (int32_t i = 1; i < n_parallel; ++i) { + // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + //} + + if (n_parallel > 1) { + LOG("\n\n%s: generating %d sequences ...\n", __func__, n_parallel); + } + + // main loop + + // we will store the parallel decoded sequences in this vector + std::vector<std::string> streams(n_parallel); + + // remember the batch index of the last token for each parallel sequence + // we need this to determine which logits to sample from + std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1); + + int n_cur = batch.n_tokens; + int n_decode = 0; + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_predict) { + // prepare the next batch + common_batch_clear(batch); + + // sample the next token for each parallel sequence / stream + for (int32_t i = 0; i < n_parallel; ++i) { + if (i_batch[i] < 0) { + // the stream has already finished + continue; + } + + const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]); + + // is it an end of generation? 
-> mark the stream as finished + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { + i_batch[i] = -1; + LOG("\n"); + if (n_parallel > 1) { + LOG_INF("%s: stream %d finished at n_cur = %d", __func__, i, n_cur); + } + + continue; + } + + // if there is only one stream, we print immediately to stdout + if (n_parallel == 1) { + LOG("%s", common_token_to_piece(ctx, new_token_id).c_str()); + } + + streams[i] += common_token_to_piece(ctx, new_token_id); + + i_batch[i] = batch.n_tokens; + + // push this new token for next evaluation + common_batch_add(batch, new_token_id, n_cur, { i }, true); + + n_decode += 1; + } + + // all streams are finished + if (batch.n_tokens == 0) { + break; + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + if (n_parallel > 1) { + LOG("\n"); + + for (int32_t i = 0; i < n_parallel; ++i) { + LOG("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str()); + } + } + + const auto t_main_end = ggml_time_us(); + + LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + LOG("\n"); + llama_perf_sampler_print(sampler_configs[0].sampler); + llama_perf_context_print(ctx); + + fprintf(stderr, "\n"); + + llama_batch_free(batch); + + for (auto & sampler_config : sampler_configs) { + llama_sampler_free(sampler_config.sampler); + } + + llama_free(ctx); + llama_model_free(model); + + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt b/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt new file mode 100644 index 0000000..44e5f72 --- /dev/null +++ b/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-convert-llama2c-to-ggml) +add_executable(${TARGET} convert-llama2c-to-ggml.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/convert-llama2c-to-ggml/README.md b/llama.cpp/examples/convert-llama2c-to-ggml/README.md new file mode 100644 index 0000000..46a42da --- /dev/null +++ b/llama.cpp/examples/convert-llama2c-to-ggml/README.md @@ -0,0 +1,25 @@ +## Convert llama2.c model to ggml + +This example reads weights from project [llama2.c](https://github.com/karpathy/llama2.c) and saves them in ggml compatible format. The vocab that is available in `models/ggml-vocab.bin` is used by default. + +To convert the model first download the models from the [llama2.c](https://github.com/karpathy/llama2.c) repository. 
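+The usage text below matches the tool's built-in `-h` help: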
+ +``` +usage: ./llama-convert-llama2c-to-ggml [options] + +options: + -h, --help show this help message and exit + --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default 'models/7B/ggml-model-f16.gguf') + --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model + --llama2c-output-model FNAME model path to save the converted llama2.c model (default ak_llama_model.bin') +``` + +An example command using a model from [karpathy/tinyllamas](https://huggingface.co/karpathy/tinyllamas) is as follows: + +`$ ./llama-convert-llama2c-to-ggml --copy-vocab-from-model llama-2-7b-chat.gguf.q2_K.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.gguf.bin` + +Note: The vocabulary for `stories260K.bin` should be its own tokenizer `tok512.bin` found in [karpathy/tinyllamas/stories260K](https://huggingface.co/karpathy/tinyllamas/tree/main/stories260K). + +Now you can use the model with a command like: + +`$ ./llama-cli -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256` diff --git a/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp new file mode 100644 index 0000000..767198a --- /dev/null +++ b/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -0,0 +1,941 @@ +#include "ggml.h" +#include "gguf.h" + +#include "llama.h" +#include "common.h" +#include "log.h" + +#include <unordered_map> +#include <vector> +#include <cassert> +#include <climits> +#include <cstring> +#include <cstdarg> +#include <cinttypes> +#include <ctime> +#include <random> +#include <stdexcept> +#include <sstream> +#include <algorithm> +#include <string> + +// GGUF keys & tensor names. 
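+// these constants mirror the gguf key names and the blk.%d tensor naming
+// scheme used in llama.cpp model files; the %d placeholders are filled in
+// with the layer index via ggml_format_name when the tensors are written out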
+ +#define KV_GENERAL_ARCHITECTURE "general.architecture" +#define KV_GENERAL_NAME "general.name" + +#define KV_TOKENIZER_MODEL "tokenizer.ggml.model" +#define KV_TOKENIZER_LIST "tokenizer.ggml.tokens" +#define KV_TOKENIZER_TOKEN_TYPE "tokenizer.ggml.token_type" +#define KV_TOKENIZER_SCORES "tokenizer.ggml.scores" +#define KV_TOKENIZER_BOS_ID "tokenizer.ggml.bos_token_id" +#define KV_TOKENIZER_EOS_ID "tokenizer.ggml.eos_token_id" +#define KV_TOKENIZER_UNK_ID "tokenizer.ggml.unknown_token_id" +#define KV_TOKENIZER_SEP_ID "tokenizer.ggml.seperator_token_id" +#define KV_TOKENIZER_PAD_ID "tokenizer.ggml.padding_token_id" +#define KV_TOKENIZER_HF_JSON "tokenizer.huggingface.json" + +#define KV_CONTEXT_LENGTH "llama.context_length" +#define KV_EMBEDDING_LENGTH "llama.embedding_length" +#define KV_BLOCK_COUNT "llama.block_count" +#define KV_FEED_FORWARD_LENGTH "llama.feed_forward_length" +#define KV_ATTENTION_HEAD_COUNT "llama.attention.head_count" +#define KV_ATTENTION_HEAD_COUNT_KV "llama.attention.head_count_kv" +#define KV_ATTENTION_LAYERNORM_RMS_EPS "llama.attention.layer_norm_rms_epsilon" +#define KV_ROPE_DIMENSION_COUNT "llama.rope.dimension_count" + +#define TN_TOKEN_EMBD "token_embd.weight" +#define TN_OUTPUT_NORM "output_norm.weight" +#define TN_OUTPUT "output.weight" +#define TN_ATTN_NORM "blk.%d.attn_norm.weight" +#define TN_ATTN_Q "blk.%d.attn_q.weight" +#define TN_ATTN_K "blk.%d.attn_k.weight" +#define TN_ATTN_V "blk.%d.attn_v.weight" +#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight" +#define TN_FFN_NORM "blk.%d.ffn_norm.weight" +#define TN_FFN_GATE "blk.%d.ffn_gate.weight" +#define TN_FFN_DOWN "blk.%d.ffn_down.weight" +#define TN_FFN_UP "blk.%d.ffn_up.weight" + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' +#define LLAMA_FILE_VERSION_GGJT_V3 3 + +#define TOKENIZER_NAME "llama" +#define UNKNOWN_TOKEN_ID 0 +#define BOS_TOKEN_ID 1 +#define EOS_TOKEN_ID 2 + +//////////////////////////////////////// llama2.c model structs and functions to load models, alloc memory etc. 
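+// Config mirrors the packed header at the start of a llama2.c checkpoint and
+// is read with a single fread, so field order and types must match the
+// llama2.c code exactly; a negative vocab_size signals that separate
+// classifier weights (wcls) follow at the end of the file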
+typedef struct { + int dim; // transformer dimension + int hidden_dim; // for ffn layers + int n_layers; // number of layers + int n_heads; // number of query heads + int n_kv_heads; // number of key/value heads (can be < query heads because of multiquery) + int vocab_size; // vocabulary size, usually 256 (byte-level) + int seq_len; // max sequence length +} Config; + +struct TransformerWeights { + // token embedding table + std::vector<float> token_embedding_table; // (vocab_size, dim) + // weights for rmsnorms + std::vector<float> rms_att_weight; // (layer, dim) rmsnorm weights + std::vector<float> rms_ffn_weight; // (layer, dim) + // weights for matmuls + std::vector<float> wq; // (layer, dim, dim) + std::vector<float> wk; // (layer, dim, dim) + std::vector<float> wv; // (layer, dim, dim) + std::vector<float> wo; // (layer, dim, dim) + // weights for ffn + std::vector<float> w1; // (layer, hidden_dim, dim) + std::vector<float> w2; // (layer, dim, hidden_dim) + std::vector<float> w3; // (layer, hidden_dim, dim) + // final rmsnorm + std::vector<float> rms_final_weight; // (dim,) + // freq_cis for RoPE relatively positional embeddings + // std::vector<float> freq_cis_real; // (seq_len, dim/2) + // std::vector<float> freq_cis_imag; // (seq_len, dim/2) + // (optional) classifier weights for the logits, on the last layer + std::vector<float> wcls; +}; + +static void alloc_weights(TransformerWeights * w, const Config * p, bool shared_weights) { + const int n_multiqueries = p->n_kv_heads <= 0 || p->n_kv_heads >= p->n_heads ? 1 : p->n_heads / p->n_kv_heads; + try { + w->token_embedding_table.resize(p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + + w->rms_att_weight.resize(p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_att_weight\n",__func__,p->n_layers, p->dim, p->n_layers * p->dim); + + w->rms_ffn_weight.resize(p->n_layers * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->rms_ffn_weight\n",__func__,p->n_layers , p->dim, p->n_layers * p->dim); + + w->wq.resize(p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wq\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + + w->wk.resize(p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wk\n",__func__,p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + + w->wv.resize(p->n_layers * p->dim * p->dim / n_multiqueries); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wv\n",__func__, p->n_layers, p->dim, p->dim / n_multiqueries, p->n_layers * p->dim * p->dim / n_multiqueries); + + w->wo.resize(p->n_layers * p->dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->wo\n",__func__,p->n_layers, p->dim, p->dim, p->n_layers * p->dim * p->dim); + + w->w1.resize(p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w1\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + + w->w2.resize(p->n_layers * p->hidden_dim * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w2\n",__func__,p->n_layers, p->dim, p->hidden_dim, p->n_layers * p->hidden_dim * p->dim); + + w->w3.resize(p->n_layers * p->hidden_dim * p->dim); + 
LOG_INF("%s: Allocating [%d] x [%d] x [%d] = [%d] float space for w->w3\n",__func__,p->n_layers, p->hidden_dim, p->dim, p->n_layers * p->hidden_dim * p->dim); + + w->rms_final_weight.resize(p->dim); + LOG_INF("%s: Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); + + if (shared_weights) { + w->wcls = {}; + } else { + w->wcls.resize(p->vocab_size * p->dim); + LOG_INF("%s: Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + } + } + catch (std::length_error &) { + die("Invalid configuration. Failed to allocate memory for weights"); + } +} + +static int checkpoint_init_weights(TransformerWeights * w, const Config * p, FILE * f, bool shared_weights) { + if (fread(w->token_embedding_table.data(), sizeof(float), w->token_embedding_table.size(), f) != w->token_embedding_table.size()) return 1; + if (fread(w->rms_att_weight.data(), sizeof(float), w->rms_att_weight.size(), f) != w->rms_att_weight.size()) return 1; + if (fread(w->wq.data(), sizeof(float), w->wq.size(), f) != w->wq.size()) return 1; + if (fread(w->wk.data(), sizeof(float), w->wk.size(), f) != w->wk.size()) return 1; + if (fread(w->wv.data(), sizeof(float), w->wv.size(), f) != w->wv.size()) return 1; + if (fread(w->wo.data(), sizeof(float), w->wo.size(), f) != w->wo.size()) return 1; + if (fread(w->rms_ffn_weight.data(), sizeof(float), w->rms_ffn_weight.size(), f) != w->rms_ffn_weight.size()) return 1; + if (fread(w->w1.data(), sizeof(float), w->w1.size(), f) != w->w1.size()) return 1; + if (fread(w->w2.data(), sizeof(float), w->w2.size(), f) != w->w2.size()) return 1; + if (fread(w->w3.data(), sizeof(float), w->w3.size(), f) != w->w3.size()) return 1; + if (fread(w->rms_final_weight.data(), sizeof(float), w->rms_final_weight.size(), f) != w->rms_final_weight.size()) return 1; + + // Skip freq_cis_real & freq_cis_imag + int head_size = p->dim / p->n_heads; + fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR); + + if (!shared_weights && fread(w->wcls.data(), sizeof(float), w->wcls.size(), f) != w->wcls.size()) return 1; + + // Check we didn't forget to read anything + auto curr = ftell(f); + fseek(f, 0, SEEK_END); + auto end = ftell(f); + if (curr != end) { + LOG_ERR("%s: Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", __func__, curr, end); + return 1; + } + + return 0; +} + +static void print_sample_weights(TransformerWeights *w){ + LOG_INF("----- Quick print of first of the weight vales of all the variables\n"); + LOG_INF("%f\n", w->token_embedding_table[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + LOG_INF("%f\n", w->rms_ffn_weight[0]); + + LOG_INF("%f\n", w->wq[0]); + LOG_INF("%f\n", w->wk[0]); + LOG_INF("%f\n", w->wv[0]); + LOG_INF("%f\n", w->wo[0]); + LOG_INF("%f\n", w->w1[0]); + LOG_INF("%f\n", w->w2[0]); + LOG_INF("%f\n", w->w3[0]); + LOG_INF("%f\n", w->rms_att_weight[0]); + if (!w->wcls.empty()) LOG_INF("%f\n", w->wcls[0]); +} +//////////////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model. 
+ +struct my_llama_vocab { + using id = int32_t; + using token = std::string; + using ttype = llama_token_type; + + struct token_data { + token text; + float score; + ttype type; + }; + + std::unordered_map<token, id> token_to_id; + std::vector<token_data> id_to_token; +}; + +struct my_llama_hparams { + uint32_t n_vocab = 32000; + uint32_t n_ctx = 512; // this is provided as user input? + uint32_t n_embd = 4096; + uint32_t n_ff = 11008; + uint32_t n_mult = 4; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + + bool operator!=(const my_llama_hparams& other) const { + return memcmp(this, &other, sizeof(my_llama_hparams)); + } +}; + +struct my_llama_layer { + // normalization + struct ggml_tensor * attention_norm; + + // attention + struct ggml_tensor * wq; + struct ggml_tensor * wk; + struct ggml_tensor * wv; + struct ggml_tensor * wo; + + // normalization + struct ggml_tensor * ffn_norm; + + // ff + struct ggml_tensor * w1; + struct ggml_tensor * w2; + struct ggml_tensor * w3; +}; + +struct my_llama_model { + struct ggml_context * ctx = NULL; + + std::string name; + + my_llama_hparams hparams; + + struct ggml_tensor * tok_embeddings; + + struct ggml_tensor * norm; + struct ggml_tensor * output; + + std::vector<my_llama_layer> layers; + + uint32_t train_its = 0; + uint32_t train_samples = 0; + uint32_t train_tokens = 0; +}; + +struct train_params { + const char * fn_vocab_model; + const char * fn_llama2c_model; + const char * fn_llama2c_output_model; + const char * fn_train_data; + const char * fn_checkpoint_in; + const char * fn_checkpoint_out; + const char * fn_model_out; + + uint32_t seed; + + int n_ctx; + int n_embd; + int n_mult; + int n_head; + int n_layer; + int n_rotmax; + + int n_threads; + int n_batch; + int n_examples; + int n_predict; + + int print_info_interval; + int print_details_interval; + + bool samples_start_after_nl; + bool use_adam; + bool use_flash; + bool use_scratch; + + // only adam + int warmup; + int cos_decay_steps; + float cos_decay_restart; + float cos_decay_alpha; + + int lbfgs_n_iter; + int adam_n_iter; + float adam_alpha; + float adam_decay; + + int mem_model_gb; + int mem_compute_gb; + int mem_compute0_gb; + int mem_compute1_gb; +}; + +static void print_params(struct my_llama_hparams * params) { + LOG_INF("%s: n_vocab: %u\n", __func__, params->n_vocab); + LOG_INF("%s: n_ctx: %u\n", __func__, params->n_ctx); + LOG_INF("%s: n_embd: %u\n", __func__, params->n_embd); + LOG_INF("%s: n_mult: %u\n", __func__, params->n_mult); + LOG_INF("%s: n_head: %u\n", __func__, params->n_head); + LOG_INF("%s: n_head_kv: %u\n", __func__, params->n_head_kv); + LOG_INF("%s: n_ff: %u\n", __func__, params->n_ff); + LOG_INF("%s: n_layer: %u\n", __func__, params->n_layer); + LOG_INF("%s: n_rot: %u\n", __func__, params->n_rot); +} + +static void print_tensor_info(const struct ggml_context * ctx) { + for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + LOG_INF("%s: Allocating ", __func__); + int64_t total = 1; + int i = 0; + for (; i < ggml_n_dims(t); ++i) { + if (i > 0) { LOG_INF("x "); } + LOG_INF("[%" PRId64 "] ", t->ne[i]); + total *= t->ne[i]; + } + if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); } + LOG_INF("float space for %s\n", ggml_get_name(t)); + } +} + +static void init_model(struct my_llama_model * model) { + const auto & hparams = model->hparams; + + const uint32_t n_embd = hparams.n_embd; + const uint32_t n_layer = hparams.n_layer; + const uint32_t n_vocab = hparams.n_vocab; + + 
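+    // grouped-query attention: when n_head_kv < n_head, the k/v projections
+    // are a factor of n_head/n_head_kv narrower than the q projection;
+    // n_multiqueries below is that ratio (1 for plain multi-head attention)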
const uint32_t n_multiqueries = hparams.n_head_kv <= 0 || hparams.n_head_kv >= hparams.n_head ? 1 : hparams.n_head / hparams.n_head_kv; + + const uint32_t n_ff = hparams.n_ff; + struct ggml_context * ctx = model->ctx; + + model->train_its = 0; + model->train_samples = 0; + model->train_tokens = 0; + + model->tok_embeddings = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + model->norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + model->output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); + + ggml_set_name(model->tok_embeddings, "tok_embeddings.weight"); + ggml_set_name(model->norm, "norm.weight"); + ggml_set_name(model->output, "output.weight"); + + model->layers.resize(n_layer); + for (uint32_t i = 0; i < n_layer; ++i) { + auto & layer = model->layers[i]; + + std::string layers_i = "layers." + std::to_string(i); + + layer.attention_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.wq = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + layer.wk = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries); + layer.wv = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd / n_multiqueries); + layer.wo = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_embd); + + layer.ffn_norm = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd); + + layer.w1 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + layer.w2 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd); + layer.w3 = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff); + + ggml_set_name(layer.attention_norm, (layers_i + ".attention_norm.weight").c_str()); + + ggml_set_name(layer.wq, (layers_i + ".attention.wq.weight").c_str()); + ggml_set_name(layer.wk, (layers_i + ".attention.wk.weight").c_str()); + ggml_set_name(layer.wv, (layers_i + ".attention.wv.weight").c_str()); + ggml_set_name(layer.wo, (layers_i + ".attention.wo.weight").c_str()); + + ggml_set_name(layer.ffn_norm, (layers_i + ".ffn_norm.weight").c_str()); + + ggml_format_name(layer.w1, "%s.feed_forward.w1.weight", layers_i.c_str()); + ggml_format_name(layer.w2, "%s.feed_forward.w2.weight", layers_i.c_str()); + ggml_format_name(layer.w3, "%s.feed_forward.w3.weight", layers_i.c_str()); + } + + print_tensor_info(ctx); +} + +static float get_f32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + float * ptr = (float *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +static int32_t get_i32_2d(struct ggml_tensor * tensor, int64_t i0, int64_t i1) { + int32_t * ptr = (int32_t *) ((char *) tensor->data + i0*tensor->nb[0] + i1*tensor->nb[1]); + return *ptr; +} + +static void print_row(struct ggml_tensor * probs, int i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + LOG(" %f", p); + } + LOG("\n"); +} + +static void print_matrix(struct ggml_tensor * probs) { + assert(ggml_is_matrix(probs)); + for (int i = 0; i < probs->ne[1]; ++i) { + for (int k = 0; k < probs->ne[0]; ++k) { + float p = get_f32_2d(probs, k, i); + LOG(" %.2f", p); + } + LOG("\n"); + } +} + +struct my_llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + my_llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + size = 0; + } else { + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return 
(size_t) ret; + } + + void seek(size_t offset, int whence) { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t size) { + if (size == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, size, 1, fp); + if (ferror(fp)) { + die_fmt("fread failed: %s", strerror(errno)); + } + if (ret != 1) { + die("unexpectedly reached end of file"); + } + } + + std::uint32_t read_u32() { + std::uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + std::float_t read_f32() { + std::float_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + std::string read_string(std::uint32_t len) { + std::vector<char> chars(len); + read_raw(chars.data(), len); + return std::string(chars.data(), len); + } + + ~my_llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +static bool is_ggml_file(const char * filename) { + my_llama_file file(filename, "rb"); + if (file.size < 4) { + return false; + } + std::string magic = file.read_string(4); + return magic == GGUF_MAGIC; +} + +static std::string llama_escape_whitespaces(const std::string & text) { + std::ostringstream out; + for (char c : text) { + if (c == ' ') out << "\xe2\x96\x81"; + else out << c; + } + return out.str(); +} + +static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) { + if (is_ggml_file(filename)) { + LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename); + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + struct gguf_context * ctx = gguf_init_from_file(filename, params); + GGML_ASSERT(ctx != NULL); + + const int model_idx = gguf_find_key(ctx, KV_TOKENIZER_MODEL); + GGML_ASSERT(model_idx >= 0); + std::string tokenizer_name = gguf_get_val_str(ctx, model_idx); + GGML_ASSERT(tokenizer_name == TOKENIZER_NAME); + + const int token_idx = gguf_find_key(ctx, KV_TOKENIZER_LIST); + GGML_ASSERT(token_idx >= 0); + + const int score_idx = gguf_find_key(ctx, KV_TOKENIZER_SCORES); + GGML_ASSERT(score_idx >= 0); + const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + + const int toktype_idx = gguf_find_key(ctx, KV_TOKENIZER_TOKEN_TYPE); + GGML_ASSERT(toktype_idx >= 0); + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + if (n_vocab != static_cast<uint32_t>(config->vocab_size)) { + die_fmt("vocab size mismatch: (gguf) %u != (llama2c) %d", n_vocab, config->vocab_size); + } + + vocab->id_to_token.resize(n_vocab); + + for (uint32_t i = 0; i < n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + + vocab->token_to_id[word] = i; + + auto & token_data = vocab->id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores[i]; + token_data.type = (llama_token_type) toktypes[i]; + } + ggml_free(ctx_data); + gguf_free(ctx); + } else { + // assume llama2.c vocabulary + LOG_INF("%s: Assuming llama2.c vocabulary since %s is not a gguf file\n", __func__, filename); + my_llama_file file(filename, "rb"); + if (!file.fp) { + die_fmt("%s: %s", strerror(errno), filename); + } + const int n_vocab = config->vocab_size; + /* uint32_t max_token_length = */ file.read_u32(); // unused + vocab->id_to_token.resize(n_vocab); + for (my_llama_vocab::id id=0; id<n_vocab; ++id) { + float_t score = file.read_f32(); + 
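+            // each llama2.c vocab entry is: f32 score (read above), u32 byte
+            // length, then the raw utf-8 bytes of the token text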
uint32_t len = file.read_u32(); + std::string text = file.read_string(len); + + unsigned char byte_val; + my_llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL; + if (id == UNKNOWN_TOKEN_ID) { + text = "<unk>"; + type = LLAMA_TOKEN_TYPE_UNKNOWN; + } else if (id == BOS_TOKEN_ID) { + text = "<s>"; + type = LLAMA_TOKEN_TYPE_CONTROL; + } else if (id == EOS_TOKEN_ID) { + text = "</s>"; + type = LLAMA_TOKEN_TYPE_CONTROL; + } else if (text.empty()) { + type = LLAMA_TOKEN_TYPE_CONTROL; + } else if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) { + // Text of byte tokens is already in the expected format. + type = LLAMA_TOKEN_TYPE_BYTE; + } else { + type = LLAMA_TOKEN_TYPE_NORMAL; + } + text = llama_escape_whitespaces(text); + + vocab->id_to_token[id].text = text; + vocab->id_to_token[id].score = score; + vocab->id_to_token[id].type = type; + vocab->token_to_id.emplace(text, id); + } + } +} + +static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const float * karpathy_weights) { + int size = 1; + for (int dim = 0; dim < ggml_n_dims(gg_weights); ++dim) { + size *= gg_weights->ne[dim]; + } + for (int ct = 0; ct < size; ++ct) { + int64_t i0 = 0; int64_t i1 = 0; + int64_t i2 = 0; int64_t i3 = 0; + ggml_unravel_index(gg_weights, ct, &i0, &i1, &i2, &i3); + ggml_set_f32_nd(gg_weights, i0, i1, i2, i3, karpathy_weights[ct]); + } +} + +static void save_as_llama_model( + struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename +) { + // convert AK weights into GG weights one by one. + // w->token_embedding_table -> model->tok_embeddings + // float* -> struct ggml_tensor + convert_weights_ak_to_gg(model->tok_embeddings, w->token_embedding_table.data()); + convert_weights_ak_to_gg(model->output, !w->wcls.empty() ? w->wcls.data() : w->token_embedding_table.data()); + + convert_weights_ak_to_gg(model->norm, w->rms_final_weight.data()); + //print_row(model->norm, 0); + + // for rms-att-weight + int row_length = model->hparams.n_embd; + int n_ff = model->hparams.n_ff; + + const uint32_t n_multiqueries = model->hparams.n_head_kv <= 0 || model->hparams.n_head_kv >= model->hparams.n_head ? 
1 : model->hparams.n_head / model->hparams.n_head_kv; + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ + auto & layer = model->layers[i]; + // 1d + convert_weights_ak_to_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); + convert_weights_ak_to_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); + + // from 3d matrix layer x dim x dim to 2d matrix dim x dim + convert_weights_ak_to_gg(layer.wq , &w->wq[i*row_length*row_length]); + convert_weights_ak_to_gg(layer.wo , &w->wo[i*row_length*row_length]); + // from 3d matrix layer x dim x dim to 2d matrix dim x dim / n_multiqueries + convert_weights_ak_to_gg(layer.wk , &w->wk[i*row_length*row_length/n_multiqueries]); + convert_weights_ak_to_gg(layer.wv , &w->wv[i*row_length*row_length/n_multiqueries]); + + convert_weights_ak_to_gg(layer.w1 , &w->w1[i*row_length*n_ff]); + convert_weights_ak_to_gg(layer.w2 , &w->w2[i*n_ff*row_length]); + convert_weights_ak_to_gg(layer.w3 , &w->w3[i*row_length*n_ff]); + } + + struct gguf_context * ctx = gguf_init_empty(); + + std::vector<const char*> tokens; + std::vector<float> scores; + std::vector<llama_token_type> token_types; + for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) { + tokens.push_back(token_data.text.c_str()); + scores.push_back(token_data.score); + token_types.push_back(token_data.type); + } + gguf_set_arr_str(ctx, KV_TOKENIZER_LIST, tokens.data(), tokens.size()); + gguf_set_arr_data(ctx, KV_TOKENIZER_SCORES, GGUF_TYPE_FLOAT32, scores.data(), scores.size()); + gguf_set_arr_data(ctx, KV_TOKENIZER_TOKEN_TYPE, GGUF_TYPE_INT32, token_types.data(), token_types.size()); + + gguf_set_val_str(ctx, KV_TOKENIZER_MODEL, TOKENIZER_NAME); + + gguf_set_val_str(ctx, KV_GENERAL_ARCHITECTURE, "llama"); + gguf_set_val_str(ctx, KV_GENERAL_NAME, "llama"); + + // special tokens + gguf_set_val_u32(ctx, KV_TOKENIZER_UNK_ID, UNKNOWN_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_BOS_ID, BOS_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_EOS_ID, EOS_TOKEN_ID); + gguf_set_val_u32(ctx, KV_TOKENIZER_SEP_ID, LLAMA_TOKEN_NULL); + gguf_set_val_u32(ctx, KV_TOKENIZER_PAD_ID, LLAMA_TOKEN_NULL); + + gguf_set_val_u32(ctx, KV_CONTEXT_LENGTH, model->hparams.n_ctx); + gguf_set_val_u32(ctx, KV_EMBEDDING_LENGTH, model->hparams.n_embd); + gguf_set_val_u32(ctx, KV_FEED_FORWARD_LENGTH, model->hparams.n_ff); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT, model->hparams.n_head); + gguf_set_val_u32(ctx, KV_ATTENTION_HEAD_COUNT_KV, model->hparams.n_head_kv); + gguf_set_val_u32(ctx, KV_BLOCK_COUNT, model->hparams.n_layer); + gguf_set_val_u32(ctx, KV_ROPE_DIMENSION_COUNT, model->hparams.n_rot); + gguf_set_val_f32(ctx, KV_ATTENTION_LAYERNORM_RMS_EPS, 1e-5f); + + // write tensors + ggml_set_name(model->tok_embeddings, TN_TOKEN_EMBD); + gguf_add_tensor(ctx, model->tok_embeddings); + + ggml_set_name(model->norm, TN_OUTPUT_NORM); + gguf_add_tensor(ctx, model->norm); + + ggml_set_name(model->output, TN_OUTPUT); + gguf_add_tensor(ctx, model->output); + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + ggml_format_name(layer.wq, TN_ATTN_Q, i); + gguf_add_tensor(ctx, layer.wq); + + ggml_format_name(layer.wk, TN_ATTN_K, i); + gguf_add_tensor(ctx, layer.wk); + + ggml_format_name(layer.wv, TN_ATTN_V, i); + gguf_add_tensor(ctx, layer.wv); + + ggml_format_name(layer.wo, TN_ATTN_OUTPUT, i); + gguf_add_tensor(ctx, layer.wo); + + ggml_format_name(layer.attention_norm, TN_ATTN_NORM, i); + 
gguf_add_tensor(ctx, layer.attention_norm); + + ggml_format_name(layer.w1, TN_FFN_GATE, i); + gguf_add_tensor(ctx, layer.w1); + + ggml_format_name(layer.w2, TN_FFN_DOWN, i); + gguf_add_tensor(ctx, layer.w2); + + ggml_format_name(layer.w3, TN_FFN_UP, i); + gguf_add_tensor(ctx, layer.w3); + + ggml_format_name(layer.ffn_norm, TN_FFN_NORM, i); + gguf_add_tensor(ctx, layer.ffn_norm); + } + + gguf_write_to_file(ctx, filename, false); + gguf_free(ctx); +} + +static struct train_params get_default_train_params() { + struct train_params params; + params.fn_vocab_model = "models/7B/ggml-model-f16.gguf"; + params.fn_llama2c_output_model = "ak_llama_model.bin"; + params.fn_train_data = "shakespeare.txt"; + params.fn_checkpoint_in = "checkpoint.bin"; + params.fn_checkpoint_out = "checkpoint.bin"; + params.fn_model_out = "ggml-checkpoint-f32.bin"; + + params.seed = -1; + + params.n_ctx = 128; + params.n_embd = 256; + params.n_mult = 256; + params.n_head = 8; + params.n_layer = 16; + params.n_rotmax = 64; + + params.n_threads = 6; + params.n_batch = 8; + params.n_examples = 8; + params.n_predict = 1024; + + params.print_info_interval = 1; + params.print_details_interval = 2; + + params.samples_start_after_nl = false; + params.use_adam = true; + params.use_flash = false; + params.use_scratch = true; + + // only adam + params.warmup = 100; + params.cos_decay_steps = 1000; + params.cos_decay_restart = 1.1f; + params.cos_decay_alpha = 0.0f; + + params.lbfgs_n_iter = 16; + params.adam_n_iter = 16; + params.adam_alpha = 1e-3f; + params.adam_decay = 1e-3f; + + params.mem_model_gb = 2; + params.mem_compute_gb = 24; + params.mem_compute0_gb = 8; + params.mem_compute1_gb = 2; + + return params; +} + +static void print_usage(int /*argc*/, char ** argv, const struct train_params * params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " --copy-vocab-from-model FNAME path of gguf llama model or llama2.c vocabulary from which to copy vocab (default '%s')\n", params->fn_vocab_model); + fprintf(stderr, " --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model\n"); + fprintf(stderr, " --llama2c-output-model FNAME model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model); + fprintf(stderr, "\n"); +} + +static bool params_parse(int argc, char ** argv, struct train_params * params) { + bool invalid_param = false; + bool reqd_param_found = false; + std::string arg; + struct train_params default_params = get_default_train_params(); + const std::string arg_prefix = "--"; + + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + if (arg == "--copy-vocab-from-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_vocab_model = argv[i]; + } else if (arg == "--llama2c-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + reqd_param_found = true; + params->fn_llama2c_model = argv[i]; + } else if (arg == "--llama2c-output-model") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_llama2c_output_model = argv[i]; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv, &default_params); + exit(0); + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + print_usage(argc, argv, &default_params); + 
exit(1); + } + } + if (invalid_param) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + print_usage(argc, argv, &default_params); + exit(1); + } + if (!reqd_param_found){ + fprintf(stderr, "error: please specify a llama2.c .bin file to be converted with argument --llama2c-model\n"); + print_usage(argc, argv, &default_params); + exit(1); + } + + return true; +} + +static std::string basename(const std::string &path) { + size_t pos = path.find_last_of("/\\"); + if (pos == std::string::npos) { + return path; + } + return path.substr(pos + 1); +} + +int main(int argc, char ** argv) { + common_init(); + + struct train_params params = get_default_train_params(); + if (!params_parse(argc, argv, ¶ms)) { + return 1; + } + + Config config; + TransformerWeights weights = {}; + { + LOG_INF("%s: Loading llama2c model from %s\n", __func__, params.fn_llama2c_model); + FILE * file = fopen(params.fn_llama2c_model, "rb"); + if (!file) { + LOG_ERR("%s: Unable to open the checkpoint file %s!\n", __func__, params.fn_llama2c_model); + return 1; + } + // read in the config header + if (fread(&config, sizeof(Config), 1, file) != 1) { + LOG_ERR("%s: Unable to read llama2c config from %s!\n",__func__,params.fn_llama2c_model); + return 1; + } + auto shared_weights = config.vocab_size > 0; + config.vocab_size = abs(config.vocab_size); + + // read in the Transformer weights + alloc_weights(&weights, &config, shared_weights); + if (checkpoint_init_weights(&weights, &config, file, shared_weights)) { + LOG_ERR("%s: Unable to initialize transformer weights from %s!",__func__,params.fn_llama2c_model); + return 1; + } + fclose(file); + } + + struct my_llama_vocab vocab; + load_vocab(params.fn_vocab_model, &config, &vocab); + + struct my_llama_model model; + model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx); + model.hparams.n_ctx = params.n_ctx; + model.hparams.n_embd = config.dim; //params.n_embd; + model.hparams.n_ff = config.hidden_dim; + model.hparams.n_mult = 32;//params.n_mult; + model.hparams.n_head = config.n_heads; //params.n_head; + model.hparams.n_head_kv = config.n_kv_heads; + model.hparams.n_layer = config.n_layers; //params.n_layer; + model.hparams.n_rot = std::min((uint32_t)params.n_rotmax, model.hparams.n_embd / model.hparams.n_head); + + print_params(&model.hparams); + + struct ggml_init_params lcparams; + lcparams.mem_size = 1024ll*1024ll*1024ll*((size_t) params.mem_model_gb); + lcparams.mem_buffer = NULL; + lcparams.no_alloc = false; + + model.ctx = ggml_init(lcparams); + + init_model(&model); + model.name = basename(params.fn_llama2c_model); + save_as_llama_model(&vocab, &model, &weights, params.fn_llama2c_output_model); + + LOG_INF("%s: Saving llama.c model file %s in ggml format at %s\n", __func__, params.fn_llama2c_model, params.fn_llama2c_output_model); + + ggml_free(model.ctx); + return 0; +} diff --git a/llama.cpp/examples/convert_legacy_llama.py b/llama.cpp/examples/convert_legacy_llama.py new file mode 100755 index 0000000..c4ec5c5 --- /dev/null +++ b/llama.cpp/examples/convert_legacy_llama.py @@ -0,0 +1,1462 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import logging +import argparse +import concurrent.futures +import enum +import faulthandler +import functools +import itertools +import json +import math +import mmap +import os +import pickle +import re +import signal +import struct +import sys +import textwrap +import time +import zipfile +from abc import ABC, abstractmethod +from concurrent.futures import 
ProcessPoolExecutor, ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, IO, Iterable, Literal, TypeVar + +import numpy as np + +if 'NO_LOCAL_GGUF' not in os.environ: + # use .parent.parent since we are in "examples" directory + sys.path.insert(1, str(Path(__file__).parent.parent / 'gguf-py')) + +import gguf +from gguf import BaseVocab, Vocab, NoVocab, BpeVocab, SentencePieceVocab, LlamaHfVocab + +if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + +logger = logging.getLogger("convert") + +if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): + faulthandler.register(signal.SIGUSR1) + +NDArray: TypeAlias = 'np.ndarray[Any, Any]' + +ARCH = gguf.MODEL_ARCH.LLAMA + +DEFAULT_CONCURRENCY = 8 + +ADDED_TOKENS_FILE = 'added_tokens.json' +FAST_TOKENIZER_FILE = 'tokenizer.json' + +# +# data types +# + + +@dataclass(frozen=True) +class DataType: + name: str + dtype: np.dtype[Any] + valid_conversions: list[str] + + def elements_to_bytes(self, n_elements: int) -> int: + return n_elements * self.dtype.itemsize + + +@dataclass(frozen=True) +class UnquantizedDataType(DataType): + pass + + +DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) +DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) +DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) +DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) + + +@dataclass(frozen=True) +class QuantizedDataType(DataType): + block_size: int + quantized_dtype: np.dtype[Any] + ggml_type: gguf.GGMLQuantizationType + + def quantize(self, arr: NDArray) -> NDArray: + raise NotImplementedError(f'Quantization for {self.name} not implemented') + + def elements_to_bytes(self, n_elements: int) -> int: + assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + return self.quantized_dtype.itemsize * (n_elements // self.block_size) + + +@dataclass(frozen=True) +class Q8_0QuantizedDataType(QuantizedDataType): + # Mini Q8_0 quantization in Python! 
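+    # each Q8_0 block packs one f16 scale d = max(|x|) / 127 followed by 32
+    # int8 quants q = round(x / d), matching the quantized_dtype declared
+    # for DT_Q8_0 below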
+    def quantize(self, arr: NDArray) -> NDArray:
+        assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}'
+        assert arr.dtype == np.float32, f'Bad array type {arr.dtype}'
+        n_blocks = arr.size // self.block_size
+        blocks = arr.reshape((n_blocks, self.block_size))
+        # Much faster implementation of block quantization contributed by @Cebtenzzre
+
+        def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]:
+            d = abs(blocks).max(axis = 1) / np.float32(127)
+            with np.errstate(divide = 'ignore'):
+                qs = (blocks / d[:, None]).round()
+            qs[d == 0] = 0
+            yield from zip(d, qs)
+        return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype)
+
+
+DT_Q8_0 = Q8_0QuantizedDataType('Q8_0',
+                                dtype = np.dtype(np.float32), valid_conversions = [],
+                                ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32,
+                                quantized_dtype = np.dtype([('d', '<f2'), ('qs', 'i1', (32,))]))
+
+# Quantized types skipped here because they may also map to np.float32
+NUMPY_TYPE_TO_DATA_TYPE: dict[np.dtype[Any], DataType] = {}
+for dt in (DT_BF16, DT_F16, DT_F32, DT_I32):
+    if dt.dtype in NUMPY_TYPE_TO_DATA_TYPE:
+        raise ValueError(f'Invalid duplicate data type {dt}')
+    NUMPY_TYPE_TO_DATA_TYPE[dt.dtype] = dt
+
+SAFETENSORS_DATA_TYPES: dict[str, DataType] = {
+    'BF16': DT_BF16,
+    'F16': DT_F16,
+    'F32': DT_F32,
+    'I32': DT_I32,
+}
+
+# TODO: match this with `llama_ftype`
+# TODO: rename to LLAMAFileType
+# TODO: move to `gguf.py`
+
+
+class GGMLFileType(enum.IntEnum):
+    AllF32 = 0
+    MostlyF16 = 1  # except 1d tensors
+    MostlyQ8_0 = 7  # except 1d tensors
+
+    def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType:
+        dt = GGML_FILE_TYPE_TO_DATA_TYPE.get(self)
+        if dt is None:
+            raise ValueError(self)
+        # Convert all 1D tensors to F32. Most of the codebase that takes in 1D tensors only handles F32 tensors, and most of the output tensors are F32.
+        # Also, the 1D tensors aren't much of a performance/size issue, so instead of maintaining separate F32 and F16 implementations, just convert everything to F32 for now.
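+        # Example (editor's note, not an upstream comment): under MostlyF16 a 2-D
+        # tensor such as "layers.0.attention.wq.weight" keeps DT_F16, while a 1-D
+        # tensor such as "norm.weight" falls back to DT_F32 via the return below.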
+ return dt if len(tensor.shape) > 1 else DT_F32 + + +GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { + GGMLFileType.AllF32 : DT_F32, + GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.MostlyQ8_0: DT_Q8_0, +} + +# +# hparams loading +# + + +@dataclass +class Params: + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None + n_experts_used: int | None = None + f_norm_eps: float | None = None + + rope_scaling_type: gguf.RopeScalingType | None = None + f_rope_freq_base: float | None = None + f_rope_scale: float | None = None + n_ctx_orig: int | None = None + rope_finetuned: bool | None = None + + ftype: GGMLFileType | None = None + + # path to the directory containing the model files + path_model: Path | None = None + + @staticmethod + def guessed(model: LazyModel) -> Params: + # try transformer naming first + n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + + # try transformer naming first + if "model.layers.0.self_attn.q_proj.weight" in model: + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) + elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming + n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + else: + n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + + if n_layer < 1: + msg = """\ + failed to guess 'n_layer'. This model is unknown or unsupported. + Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed + + # TODO: verify this + n_ff = int(2 * (4 * n_embd) / 3) + n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) + + return Params( + n_vocab = n_vocab, + n_embd = n_embd, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, + ) + + @staticmethod + def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + rope_scaling_type = f_rope_scale = n_ctx_orig = rope_finetuned = None + rope_scaling = config.get("rope_scaling") + + if rope_scaling is not None and (typ := rope_scaling.get("type")): + rope_factor = rope_scaling.get("factor") + f_rope_scale = rope_factor + if typ == "linear": + rope_scaling_type = gguf.RopeScalingType.LINEAR + elif typ == "yarn": + rope_scaling_type = gguf.RopeScalingType.YARN + n_ctx_orig = rope_scaling['original_max_position_embeddings'] + rope_finetuned = rope_scaling['finetuned'] + else: + raise NotImplementedError(f'Unknown rope scaling type: {typ}') + + if "max_sequence_length" in config: + n_ctx = config["max_sequence_length"] + elif "max_position_embeddings" in config: + n_ctx = config["max_position_embeddings"] + else: + msg = """\ + failed to guess 'n_ctx'. This model is unknown or unsupported. 
+ Suggestion: provide 'config.json' of the model in the same directory containing model files.""" + raise KeyError(textwrap.dedent(msg)) + + n_experts = None + n_experts_used = None + + if "num_local_experts" in config: + n_experts = config["num_local_experts"] + n_experts_used = config["num_experts_per_tok"] + + return Params( + n_vocab = config["vocab_size"], + n_embd = config["hidden_size"], + n_layer = config["num_hidden_layers"], + n_ctx = n_ctx, + n_ff = config["intermediate_size"], + n_head = (n_head := config["num_attention_heads"]), + n_head_kv = config.get("num_key_value_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["rms_norm_eps"], + f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, + f_rope_scale = f_rope_scale, + n_ctx_orig = n_ctx_orig, + rope_finetuned = rope_finetuned, + ) + + # LLaMA v2 70B params.json + # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} + @staticmethod + def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: + with open(config_path) as f: + config = json.load(f) + + n_experts = None + n_experts_used = None + f_rope_freq_base = None + n_ff = None + + # hack to determine LLaMA v1 vs v2 vs CodeLlama + if config.get("moe"): + # Mixtral + n_ctx = 32768 + elif config.get("rope_theta") == 1000000: + # CodeLlama + n_ctx = 16384 + elif config["norm_eps"] == 1e-05: + # LLaMA v2 + n_ctx = 4096 + else: + # LLaMA v1 + n_ctx = 2048 + + if "layers.0.feed_forward.w1.weight" in model: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + + if config.get("moe"): + n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] + n_experts = config["moe"]["num_experts"] + n_experts_used = config["moe"]["num_experts_per_tok"] + f_rope_freq_base = 1e6 + + assert n_ff is not None + + return Params( + n_vocab = model["tok_embeddings.weight"].shape[0], + n_embd = config["dim"], + n_layer = config["n_layers"], + n_ctx = n_ctx, + n_ff = n_ff, + n_head = (n_head := config["n_heads"]), + n_head_kv = config.get("n_kv_heads", n_head), + n_experts = n_experts, + n_experts_used = n_experts_used, + f_norm_eps = config["norm_eps"], + f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + ) + + @staticmethod + def load(model_plus: ModelPlus) -> Params: + hf_config_path = model_plus.paths[0].parent / "config.json" + orig_config_path = model_plus.paths[0].parent / "params.json" + + if hf_config_path.exists(): + params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) + elif orig_config_path.exists(): + params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) + elif model_plus.format != 'none': + params = Params.guessed(model_plus.model) + else: + raise ValueError('Cannot guess params when model format is none') + + params.path_model = model_plus.paths[0].parent + + return params + + +# +# data loading +# TODO: reuse (probably move to gguf.py?) +# + + +def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +class Tensor(ABC): + ndarray: NDArray + data_type: DataType + + @abstractmethod + def astype(self, data_type: DataType) -> Self: ... 
+ @abstractmethod + def permute(self, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> Self: ... + @abstractmethod + def part(self, n_part: int) -> Self: ... + @abstractmethod + def to_ggml(self) -> GGMLCompatibleTensor: ... + + +def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: + assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + fp32_arr = bf16_arr.astype(np.uint32) << 16 + return fp32_arr.view(np.float32) + + +class UnquantizedTensor(Tensor): + def __init__(self, ndarray: NDArray): + assert isinstance(ndarray, np.ndarray) + self.ndarray = ndarray + self.data_type = NUMPY_TYPE_TO_DATA_TYPE[ndarray.dtype] + + def astype(self, data_type: DataType) -> UnquantizedTensor: + dtype = data_type.dtype + if self.data_type == DT_BF16: + self.ndarray = bf16_to_fp32(self.ndarray) + return UnquantizedTensor(self.ndarray.astype(dtype)) + + def to_ggml(self) -> Self: + return self + + def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + + def part(self, n_part: int) -> UnquantizedTensor: + r = self.ndarray.shape[0] // 3 + return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) + + def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: + return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) + + +def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: + tensor = lazy_tensor.load() + assert isinstance(tensor, UnquantizedTensor) + + # double-check: + actual_shape = list(tensor.ndarray.shape) + assert actual_shape == lazy_tensor.shape, (actual_shape, lazy_tensor.shape) + if expected_dtype is not None and expected_dtype != tensor.ndarray.dtype: + if convert: + tensor.ndarray = tensor.ndarray.astype(expected_dtype) + else: + raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + + return tensor.ndarray + + +GGMLCompatibleTensor = UnquantizedTensor + + +@dataclass +class LazyTensor: + _load: Callable[[], Tensor] + shape: list[int] + data_type: DataType + description: str + + def load(self) -> Tensor: + ret = self._load() + # Should be okay if it maps to the same numpy type? + assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ + (self.data_type, ret.data_type, self.description) + return ret + + def astype(self, data_type: DataType) -> LazyTensor: + self.validate_conversion_to(data_type) + + def load() -> Tensor: + return self.load().astype(data_type) + return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + def validate_conversion_to(self, data_type: DataType) -> None: + if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: + raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + + +LazyModel: TypeAlias = 'dict[str, LazyTensor]' + +ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none'] + +@dataclass +class ModelPlus: + model: LazyModel + paths: list[Path] # Where this was read from. + format: ModelFormat + vocab: BaseVocab | None # For GGML models (which have vocab built in), the vocab. 
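+# Worked example for `permute` above (editor's sketch, not part of the converter):
+# with n_head = 2 and an (8, 4) weight, each head's rows are de-interleaved, so the
+# row order (0,1,2,3,4,5,6,7) becomes (0,2,1,3,4,6,5,7), undoing the rotary-embedding
+# layout that HF checkpoints use:
+#
+#     w = np.arange(32, dtype=np.float32).reshape(8, 4)
+#     assert (permute(w, 2, 2)[1] == w[2]).all()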
+ + +def merge_sharded(models: list[LazyModel]) -> LazyModel: + # Original LLaMA models have each file contain one part of each tensor. + # Use a dict instead of a set to preserve order. + names = {name: None for model in models for name in model} + + def convert(name: str) -> LazyTensor: + lazy_tensors = [model[name] for model in models] + if len(lazy_tensors) == 1: + # only one file; don't go through this procedure since there might + # be quantized tensors + return lazy_tensors[0] + if len(lazy_tensors[0].shape) == 1: + # the tensor is just duplicated in every file + return lazy_tensors[0] + if name.startswith('tok_embeddings.') or \ + name.endswith('.attention.wo.weight') or \ + name.endswith('.feed_forward.w2.weight'): + # split by columns + axis = 1 + else: + # split by rows + axis = 0 + concatenated_shape = list(lazy_tensors[0].shape) + concatenated_shape[axis] = sum(tensor.shape[axis] for tensor in lazy_tensors) + + def load() -> UnquantizedTensor: + ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] + concatenated = np.concatenate(ndarrays, axis=axis) + return UnquantizedTensor(concatenated) + description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' + return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + return {name: convert(name) for name in names} + + +def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: + formats: set[ModelFormat] = set(mp.format for mp in models_plus) + assert len(formats) == 1, "different formats?" + format = formats.pop() + paths = [path for mp in models_plus for path in mp.paths] + # Use the first non-None vocab, if any. + try: + vocab = next(mp.vocab for mp in models_plus if mp.vocab is not None) + except StopIteration: + vocab = None + + if any("model.embed_tokens.weight" in mp.model for mp in models_plus): + # Transformers models put different tensors in different files, but + # don't split individual tensors between files. 
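+        # (Editor note) e.g. HF shards like "pytorch_model-00001-of-00002.bin" each
+        # hold complete tensors, so the shards can simply be merged with dict.update.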
+ model: LazyModel = {} + for mp in models_plus: + model.update(mp.model) + else: + model = merge_sharded([mp.model for mp in models_plus]) + + return ModelPlus(model, paths, format, vocab) + + +def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute(n_head, n_head_kv) + return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + + +def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: + def load() -> Tensor: + return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() + s[0] = s[0] // 3 + return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + + +def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor: + def load() -> Tensor: + tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors] + return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors])) + s = lazy_tensors[0].shape.copy() + s.insert(0, len(lazy_tensors)) + return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) + + +# Functionality that simulates `torch.load` but where individual tensors are +# only loaded into memory on demand, not all at once. +# PyTorch can't do this natively as of time of writing: +# - https://github.com/pytorch/pytorch/issues/64327 +# This allows us to de-shard without multiplying RAM usage, and also +# conveniently drops the PyTorch dependency (though we still need numpy). 
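+# Usage sketch (editor's note; the file name is hypothetical):
+#
+#     plus = lazy_load_file(Path('pytorch_model.bin'))
+#     lt = plus.model['tok_embeddings.weight']  # LazyTensor: no data read yet
+#     arr = lt.load().ndarray                   # the bytes are only read here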
+ + +@dataclass +class LazyStorageKind: + data_type: DataType + + +@dataclass +class LazyStorage: + load: Callable[[int, int], NDArray] + kind: LazyStorageKind + description: str + + +class LazyUnpickler(pickle.Unpickler): + def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): + super().__init__(fp) + self.data_base_path = data_base_path + self.zip_file = zip_file + + def persistent_load(self, pid: Any) -> Any: + assert pid[0] == 'storage' + assert isinstance(pid[1], LazyStorageKind) + data_type = pid[1].data_type + filename_stem = pid[2] + filename = f'{self.data_base_path}/{filename_stem}' + info = self.zip_file.getinfo(filename) + + def load(offset: int, elm_count: int) -> NDArray: + dtype = data_type.dtype + with self.zip_file.open(info) as fp: + fp.seek(offset * dtype.itemsize) + size = elm_count * dtype.itemsize + data = fp.read(size) + assert len(data) == size + return np.frombuffer(data, dtype) + description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + return LazyStorage(load=load, kind=pid[1], description=description) + + @staticmethod + def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, + requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + assert isinstance(storage, LazyStorage) + + def load() -> UnquantizedTensor: + elm_count = stride[0] * size[0] + return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) + description = f'pickled storage_offset={storage_offset} in {storage.description}' + return LazyTensor(load, list(size), storage.kind.data_type, description) + + @staticmethod + def rebuild_from_type_v2(func, new_type, args, state): + return func(*args) + + CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = { + # getattr used here as a workaround for mypy not being smart enough to determine + # the staticmethods have a __func__ attribute. + ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), + ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), + ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), + ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), + ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), + ('torch', 'IntStorage'): LazyStorageKind(DT_I32), + ('torch', 'Tensor'): LazyTensor, + } + + def find_class(self, module: str, name: str) -> Any: + if not module.startswith('torch'): + return super().find_class(module, name) + return self.CLASSES[(module, name)] + + +def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: + zf = zipfile.ZipFile(outer_fp) + pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + assert len(pickle_paths) == 1, pickle_paths + pickle_fp = zf.open(pickle_paths[0], 'r') + unpickler = LazyUnpickler(pickle_fp, + data_base_path=pickle_paths[0][:-4], + zip_file=zf) + model = unpickler.load() + if 'model' in model: model = model['model'] + as_dict = dict(model.items()) + return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + + +def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: + header_size, = struct.unpack('<Q', fp.read(8)) + header: dict[str, dict[str, Any]] = json.loads(fp.read(header_size)) + # Use mmap for the actual data to avoid race conditions with the file offset. 
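+    # (Editor note) safetensors layout: an 8-byte little-endian header length, the
+    # JSON header itself, then raw tensor bytes; each entry's data_offsets are
+    # relative to the start of that byte buffer, hence the 8 + header_size slice.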
mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
+    byte_buf = mapped[8 + header_size:]
+
+    def convert(info: dict[str, Any]) -> LazyTensor:
+        data_type = SAFETENSORS_DATA_TYPES[info['dtype']]
+        numpy_dtype = data_type.dtype
+        shape: list[int] = info['shape']
+        begin, end = info['data_offsets']
+        assert 0 <= begin <= end <= len(byte_buf)
+        assert end - begin == math.prod(shape) * numpy_dtype.itemsize
+        buf = byte_buf[begin:end]
+
+        def load() -> UnquantizedTensor:
+            return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape))
+        description = f'safetensors begin={begin} end={end} type={data_type} path={path}'
+        return LazyTensor(load, shape, data_type, description)
+    model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'}
+    return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None)
+
+
+def must_read(fp: IO[bytes], length: int) -> bytes:
+    ret = fp.read(length)
+    if len(ret) < length:
+        raise EOFError("unexpectedly reached end of file")
+    return ret
+
+
+@functools.lru_cache(maxsize=None)
+def lazy_load_file(path: Path) -> ModelPlus:
+    fp = open(path, 'rb')
+    first8 = fp.read(8)
+    fp.seek(0)
+    if first8[:2] == b'PK':
+        # A zip file, i.e. PyTorch format
+        return lazy_load_torch_file(fp, path)
+    elif struct.unpack('<Q', first8)[0] < 16 * 1024 * 1024:
+        # Probably safetensors
+        return lazy_load_safetensors_file(fp, path)
+    else:
+        raise ValueError(f"unknown format: {path}")
+
+
+In = TypeVar('In')
+Out = TypeVar('Out')
+
+
+def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: int | None = None, use_processpool_executor: bool = False) -> Iterable[Out]:
+    '''Parallel map, but with backpressure. If the caller doesn't call `next`
+    fast enough, this will stop calling `func` at some point rather than
+    letting results pile up in memory. Specifically, there is a max of one
+    output value buffered per thread.'''
+    if concurrency < 2:
+        yield from map(func, iterable)
+        return  # without this, execution would fall through to the executor path below
+    iterable = iter(iterable)
+    executor_class: type[ThreadPoolExecutor] | type[ProcessPoolExecutor]
+    if use_processpool_executor:
+        executor_class = ProcessPoolExecutor
+    else:
+        executor_class = ThreadPoolExecutor
+    with executor_class(max_workers=max_workers) as executor:
+        futures: list[concurrent.futures.Future[Out]] = []
+        done = False
+        for _ in range(concurrency):
+            try:
+                futures.append(executor.submit(func, next(iterable)))
+            except StopIteration:
+                done = True
+                break
+
+        while futures:
+            result = futures.pop(0).result()
+            while not done and len(futures) < concurrency:
+                try:
+                    futures.append(executor.submit(func, next(iterable)))
+                except StopIteration:
+                    done = True
+                    break
+            yield result
+
+
+def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) -> None:
+    # Handle special case where the model's vocab size is not set
+    if params.n_vocab == -1:
+        raise ValueError(
+            "The model's vocab size is set to -1 in params.json. Please update it manually."
+            + (f" Maybe {vocab.vocab_size}?"
if isinstance(vocab, Vocab) else ""),
+        )
+    if not isinstance(vocab, Vocab):
+        return  # model has no vocab
+
+    # Check for a vocab size mismatch
+    if params.n_vocab == vocab.vocab_size:
+        logger.warning("Ignoring added_tokens.json since model matches vocab size without it.")
+        return
+
+    if pad_vocab and params.n_vocab > vocab.vocab_size:
+        pad_count = params.n_vocab - vocab.vocab_size
+        logger.debug(
+            f"Padding vocab with {pad_count} token(s) - <dummy00001> through <dummy{pad_count:05}>"
+        )
+        for i in range(1, pad_count + 1):
+            vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
+            vocab.added_tokens_list.append(f"<dummy{i:05}>")
+        vocab.vocab_size = params.n_vocab
+        return
+
+    msg = f"Vocab size mismatch (model has {params.n_vocab}, but {vocab.fname_tokenizer} has {vocab.vocab_size})."
+    if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20:
+        msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})."
+    if vocab.vocab_size < params.n_vocab:
+        msg += " Add the --pad-vocab option and try again."
+
+    raise ValueError(msg)
+
+
+class OutputFile:
+    def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE):
+        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
+
+    def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None:
+        # Metadata About The Model And Its Provenance
+        name = "LLaMA"
+        if metadata is not None and metadata.name is not None:
+            name = metadata.name
+        elif params.path_model is not None:
+            name = params.path_model.name
+        elif params.n_ctx == 4096:
+            # Heuristic detection of LLaMA v2 model
+            name = "LLaMA v2"
+
+        self.gguf.add_name(name)
+
+        if metadata is not None:
+            if metadata.author is not None:
+                self.gguf.add_author(metadata.author)
+            if metadata.version is not None:
+                self.gguf.add_version(metadata.version)
+            if metadata.organization is not None:
+                self.gguf.add_organization(metadata.organization)
+
+            if metadata.finetune is not None:
+                self.gguf.add_finetune(metadata.finetune)
+            if metadata.basename is not None:
+                self.gguf.add_basename(metadata.basename)
+
+            if metadata.description is not None:
+                self.gguf.add_description(metadata.description)
+            if metadata.quantized_by is not None:
+                self.gguf.add_quantized_by(metadata.quantized_by)
+
+            if metadata.size_label is not None:
+                self.gguf.add_size_label(metadata.size_label)
+
+            if metadata.license is not None:
+                self.gguf.add_license(metadata.license)
+            if metadata.license_name is not None:
+                self.gguf.add_license_name(metadata.license_name)
+            if metadata.license_link is not None:
+                self.gguf.add_license_link(metadata.license_link)
+
+            if metadata.url is not None:
+                self.gguf.add_url(metadata.url)
+            if metadata.doi is not None:
+                self.gguf.add_doi(metadata.doi)
+            if metadata.uuid is not None:
+                self.gguf.add_uuid(metadata.uuid)
+            if metadata.repo_url is not None:
+                self.gguf.add_repo_url(metadata.repo_url)
+
+            if metadata.source_url is not None:
+                self.gguf.add_source_url(metadata.source_url)
+            if metadata.source_doi is not None:
+                self.gguf.add_source_doi(metadata.source_doi)
+            if metadata.source_uuid is not None:
+                self.gguf.add_source_uuid(metadata.source_uuid)
+            if metadata.source_repo_url is not None:
+                self.gguf.add_source_repo_url(metadata.source_repo_url)
+
+            if metadata.base_models is not None:
+                self.gguf.add_base_model_count(len(metadata.base_models))
+                for key, base_model_entry in enumerate(metadata.base_models):
+                    if "name" in base_model_entry:
+
self.gguf.add_base_model_name(key, base_model_entry["name"]) + if "author" in base_model_entry: + self.gguf.add_base_model_author(key, base_model_entry["author"]) + if "version" in base_model_entry: + self.gguf.add_base_model_version(key, base_model_entry["version"]) + if "organization" in base_model_entry: + self.gguf.add_base_model_organization(key, base_model_entry["organization"]) + if "description" in base_model_entry: + self.gguf.add_base_model_description(key, base_model_entry["description"]) + if "url" in base_model_entry: + self.gguf.add_base_model_url(key, base_model_entry["url"]) + if "doi" in base_model_entry: + self.gguf.add_base_model_doi(key, base_model_entry["doi"]) + if "uuid" in base_model_entry: + self.gguf.add_base_model_uuid(key, base_model_entry["uuid"]) + if "repo_url" in base_model_entry: + self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"]) + + if metadata.datasets is not None: + self.gguf.add_dataset_count(len(metadata.datasets)) + for key, dataset_entry in enumerate(metadata.datasets): + if "name" in dataset_entry: + self.gguf.add_dataset_name(key, dataset_entry["name"]) + if "author" in dataset_entry: + self.gguf.add_dataset_author(key, dataset_entry["author"]) + if "version" in dataset_entry: + self.gguf.add_dataset_version(key, dataset_entry["version"]) + if "organization" in dataset_entry: + self.gguf.add_dataset_organization(key, dataset_entry["organization"]) + if "description" in dataset_entry: + self.gguf.add_dataset_description(key, dataset_entry["description"]) + if "url" in dataset_entry: + self.gguf.add_dataset_url(key, dataset_entry["url"]) + if "doi" in dataset_entry: + self.gguf.add_dataset_doi(key, dataset_entry["doi"]) + if "uuid" in dataset_entry: + self.gguf.add_dataset_uuid(key, dataset_entry["uuid"]) + if "repo_url" in dataset_entry: + self.gguf.add_dataset_repo_url(key, dataset_entry["repo_url"]) + + if metadata.tags is not None: + self.gguf.add_tags(metadata.tags) + if metadata.languages is not None: + self.gguf.add_languages(metadata.languages) + + def add_meta_arch(self, params: Params) -> None: + # Metadata About The Neural Architecture Itself + self.gguf.add_vocab_size(params.n_vocab) + self.gguf.add_context_length(params.n_ctx) + self.gguf.add_embedding_length(params.n_embd) + self.gguf.add_block_count(params.n_layer) + self.gguf.add_feed_forward_length(params.n_ff) + self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) + + if params.n_experts: + self.gguf.add_expert_count(params.n_experts) + + if params.n_experts_used: + self.gguf.add_expert_used_count(params.n_experts_used) + + if params.f_norm_eps: + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) + else: + raise ValueError('f_norm_eps is None') + + if params.f_rope_freq_base is not None: + self.gguf.add_rope_freq_base(params.f_rope_freq_base) + + if params.rope_scaling_type: + assert params.f_rope_scale is not None + self.gguf.add_rope_scaling_type(params.rope_scaling_type) + self.gguf.add_rope_scaling_factor(params.f_rope_scale) + + if params.n_ctx_orig is not None: + self.gguf.add_rope_scaling_orig_ctx_len(params.n_ctx_orig) + + if params.rope_finetuned is not None: + self.gguf.add_rope_scaling_finetuned(params.rope_finetuned) + + if params.ftype is not None: + self.gguf.add_file_type(params.ftype) + + def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + tokens = [] + scores = [] + toktypes = [] 
+ + # NOTE: `all_tokens` returns the base vocabulary and added tokens + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + assert len(tokens) == vocab.vocab_size + + return tokens, scores, toktypes + + def add_meta_vocab(self, vocab: Vocab) -> None: + # Ensure that tokenizer_model is added to the GGUF model + self.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # Extract model vocabulary for model conversion + tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) + + # Add extracted token information for model conversion + self.gguf.add_token_list(tokens) + self.gguf.add_token_scores(scores) + self.gguf.add_token_types(toktypes) + + def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: + svocab.add_to_gguf(self.gguf) + + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = int(np.prod(tensor.shape)) + raw_dtype = getattr(tensor.data_type, 'ggml_type', None) + data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + data_nbytes = tensor.data_type.elements_to_bytes(n_elements) + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + + def write_meta(self) -> None: + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() + + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() + + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + logger.info( + f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + self.gguf.write_tensor_data(ndarray) + + def close(self) -> None: + self.gguf.close() + + @staticmethod + def write_vocab_only( + fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_model(params, metadata) + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + + of.write_meta() + + of.close() + + @staticmethod + def do_item(item: tuple[str, LazyTensor]) -> tuple[DataType, NDArray]: + name, lazy_tensor = item + tensor = lazy_tensor.load().to_ggml() + return (lazy_tensor.data_type, tensor.ndarray) + + @staticmethod + def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: + dt, arr = item + if not isinstance(dt, QuantizedDataType): + return arr + return dt.quantize(arr) + + @staticmethod + def write_all( + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = 
gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + metadata: gguf.Metadata | None = None, + ) -> None: + check_vocab_size(params, vocab, pad_vocab=pad_vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_model(params, metadata) + of.add_meta_arch(params) + if isinstance(vocab, Vocab): + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) + else: # NoVocab + of.gguf.add_tokenizer_model(vocab.tokenizer_model) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() + + # tensor data + of.write_tensor_data(ftype, model, concurrency) + + of.close() + + +def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: + wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type + + if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + return GGMLFileType.AllF32 + if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): + return GGMLFileType.MostlyF16 + if output_type_str == "q8_0": + return GGMLFileType.MostlyQ8_0 + + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + + raise ValueError(f"Unexpected combination of types: {name_to_type}") + + +def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]: + total_params = 0 + shared_params = 0 + expert_params = 0 + + for name, lazy_tensor in tensors: + # We don't need these + if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + continue + + # Got A Tensor + sum_weights_in_tensor: int = 1 + + # Tensor Volume + for dim in lazy_tensor.shape: + sum_weights_in_tensor *= dim + + if ".experts." in name: + if ".experts.0." 
in name:
+                expert_params += sum_weights_in_tensor
+        else:
+            shared_params += sum_weights_in_tensor
+
+        total_params += sum_weights_in_tensor
+
+    return total_params, shared_params, expert_params
+
+
+def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel:
+    return {name: tensor.astype(output_type.type_for_tensor(name, tensor))
+            for (name, tensor) in model.items()}
+
+
+def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel:
+    tmap = gguf.TensorNameMap(ARCH, params.n_layer)
+    should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, []))
+
+    tmp = model
+
+    # merge experts into one tensor
+    if params.n_experts and params.n_experts > 0:
+        for i_l in range(params.n_layer):
+            for w in range(1, 4):
+                experts = []
+                for e in range(params.n_experts):
+                    if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"])
+                        del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]
+                    elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model:
+                        experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"])
+                        del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]
+                    else:
+                        raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight")
+                tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts)
+
+    # HF models permute or pack some of the tensors, so we need to undo that
+    for i in itertools.count():
+        if f"model.layers.{i}.self_attn.q_proj.weight" in model:
+            logger.debug(f"Permuting layer {i}")
+            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head)
+            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv)
+            # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"]
+        elif f"model.layers.{i}.self_attn.W_pack.weight" in model:
+            logger.debug(f"Unpacking and permuting layer {i}")
+            tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head)
+            tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv)
+            tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2)
+            del tmp[f"model.layers.{i}.self_attn.W_pack.weight"]
+        else:
+            break
+
+    out: LazyModel = {}
+    for name, lazy_tensor in model.items():
+        tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None)
+        if name_new is None:
+            if skip_unknown:
+                logger.warning(f"Unexpected tensor name: {name} - skipping")
+                continue
+            raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)")
+
+        if tensor_type in should_skip:
+            logger.debug(f"skipping tensor {name_new}")
+            continue
+
+        logger.debug(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}")
+        out[name_new] = lazy_tensor
+
+    return out
+
+
+def nth_multifile_path(path: Path, n: int) -> Path | None:
+    '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
+    the nth path in the model.
+ ''' + # Support the following patterns: + patterns = [ + # - x.00.pth, x.01.pth, etc. + (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. + (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + # x.bin, x.bin.1, etc. + (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + ] + for regex, replacement in patterns: + if re.search(regex, path.name): + new_path = path.with_name(re.sub(regex, replacement, path.name)) + if new_path.exists(): + return new_path + return None + + +def find_multifile_paths(path: Path) -> list[Path]: + '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + the whole list of paths in the model. + ''' + ret: list[Path] = [] + for i in itertools.count(): + nth_path = nth_multifile_path(path, i) + if nth_path is None: + break + ret.append(nth_path) + if not ret: + # No matches. This should only happen if the file was named, e.g., + # foo.0, and there was no file named foo. Oh well, try to process it + # as a single file. + return [path] + return ret + + +def load_some_model(path: Path) -> ModelPlus: + '''Load a model of any supported format.''' + # Be extra-friendly and accept either a file or a directory: + if path.is_dir(): + # Check if it's a set of safetensors files first + globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + # Try the PyTorch patterns too, with lower priority + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + files = [file for glob in globs for file in path.glob(glob)] + if not files: + raise FileNotFoundError(f"Can't find model in directory {path}") + if len(files) > 1: + raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}") + path = files[0] + + paths = find_multifile_paths(path) + models_plus: list[ModelPlus] = [] + for path in paths: + logger.info(f"Loading model file {path}") + models_plus.append(lazy_load_file(path)) + + model_plus = merge_multifile_models(models_plus) + return model_plus + + +class VocabFactory: + _VOCAB_CLASSES: list[type[Vocab]] = [SentencePieceVocab, BpeVocab, LlamaHfVocab] + + def __init__(self, path: Path): + self.path = path + + def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocab.name == "bpe" + n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None + return gguf.SpecialVocab( + model_parent_path, + load_merges=load_merges, + special_token_types=None, # Predetermined or passed as a parameter + n_vocab=n_vocab, + ) + + def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: + vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES} + selected_vocabs: dict[str, type[Vocab]] = {} + for vtype in vocab_types: + try: + selected_vocabs[vtype] = vocab_classes[vtype] + except KeyError: + raise ValueError(f"Unsupported vocabulary type {vtype}") from None + + for vtype, cls in selected_vocabs.items(): + try: + vocab = cls(self.path) + break + except FileNotFoundError: + pass # ignore unavailable tokenizers + else: + raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}") + + logger.info(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}") + return vocab + + def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]: + vocab: BaseVocab + if vocab_types is None: + vocab 
= NoVocab() + else: + vocab = self._create_vocab_by_path(vocab_types) + # FIXME: Respect --vocab-dir? + special_vocab = self._create_special_vocab( + vocab, + model_parent_path, + ) + return vocab, special_vocab + + +def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str: + name = metadata.name if metadata.name is not None else None + basename = metadata.basename if metadata.basename is not None else None + finetune = metadata.finetune if metadata.finetune is not None else None + version = metadata.version if metadata.version is not None else None + size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0) + + output_type = { + GGMLFileType.AllF32: "F32", + GGMLFileType.MostlyF16: "F16", + GGMLFileType.MostlyQ8_0: "Q8_0", + }[file_type] + + return gguf.naming_convention(name, basename, finetune, version, size_label, output_type) + + +def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path: + default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata) + ret = model_paths[0].parent / f"{default_filename}.gguf" + if ret in model_paths: + logger.error( + f"Error: Default output path ({ret}) would overwrite the input. " + "Please explicitly specify a path using --outfile.") + sys.exit(1) + return ret + + +def do_dump_model(model_plus: ModelPlus) -> None: + print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100 + print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100 + print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100 + for name, lazy_tensor in model_plus.model.items(): + print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") # noqa: NP100 + + +def main(args_in: list[str] | None = None) -> None: + output_choices = ["f32", "f16"] + if np.uint32(1) == np.uint32(1).newbyteorder("<"): + # We currently only support Q8_0 output on little endian systems. 
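+        # (Editor note) newbyteorder("<") reinterprets the scalar's bytes as
+        # little-endian, so the comparison above only holds on little-endian hosts.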
+ output_choices.append("q8_0") + parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") + parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") + parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) + parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") + parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") + parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") + parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("--metadata", type=Path, help="Specify the path for an authorship metadata override file") + parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name") + parser.add_argument("--model-name", type=str, default=None, help="name of the model") + + args = parser.parse_args(args_in) + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + elif args.dump_single or args.dump or args.get_outfile: + # Avoid printing anything besides the dump output + logging.basicConfig(level=logging.WARNING) + else: + logging.basicConfig(level=logging.INFO) + + model_name = args.model_name + dir_model = args.model + + metadata = gguf.Metadata.load(args.metadata, dir_model, model_name) + + if args.get_outfile: + model_plus = load_some_model(dir_model) + params = Params.load(model_plus) + model = convert_model_names(model_plus.model, params, args.skip_unknown) + model_params_count = per_model_weight_count_estimation(model_plus.model.items()) + ftype = pick_output_type(model, args.outtype) + + if (metadata is None or metadata.name is None) and params.path_model is not None: + metadata.name = params.path_model.name + + print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100 + return + + if args.no_vocab and args.vocab_only: + raise ValueError("--vocab-only does not make sense with --no-vocab") + + if args.dump_single: + model_plus = lazy_load_file(dir_model) + do_dump_model(model_plus) + return + + if not args.vocab_only: + model_plus = load_some_model(dir_model) + else: + model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format 
= 'none', vocab = None) + + if args.dump: + do_dump_model(model_plus) + return + + endianess = gguf.GGUFEndian.LITTLE + if args.big_endian: + endianess = gguf.GGUFEndian.BIG + + params = None + if args.pad_vocab or not args.vocab_only: + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + msg = """\ + The model doesn't have a context size, and you didn't specify one with --ctx + Please specify one with --ctx: + - LLaMA v1: --ctx 2048 + - LLaMA v2: --ctx 4096""" + parser.error(textwrap.dedent(msg)) + params.n_ctx = args.ctx + + if args.outtype: + params.ftype = { + "f32": GGMLFileType.AllF32, + "f16": GGMLFileType.MostlyF16, + "q8_0": GGMLFileType.MostlyQ8_0, + }[args.outtype] + + logger.info(f"params = {params}") + + model_parent_path = model_plus.paths[0].parent + vocab_path = Path(args.vocab_dir or dir_model or model_parent_path) + vocab_factory = VocabFactory(vocab_path) + vocab_types = None if args.no_vocab else args.vocab_type.split(",") + vocab, special_vocab = vocab_factory.load_vocab(vocab_types, model_parent_path) + + if args.vocab_only: + assert isinstance(vocab, Vocab) + if not args.outfile: + raise ValueError("need --outfile if using --vocab-only") + outfile = args.outfile + if params is None: + params = Params( + n_vocab = vocab.vocab_size, + n_embd = 1, + n_layer = 1, + n_ctx = 1, + n_ff = 1, + n_head = 1, + n_head_kv = 1, + f_norm_eps = 1e-5, + ) + OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, + endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + logger.info(f"Wrote {outfile}") + return + + if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: + vocab = model_plus.vocab + + assert params is not None + + if metadata.name is None and params.path_model is not None: + metadata.name = params.path_model.name + + model_params_count = per_model_weight_count_estimation(model_plus.model.items()) + logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})") + + logger.info(f"Vocab info: {vocab}") + logger.info(f"Special vocab info: {special_vocab}") + model = model_plus.model + model = convert_model_names(model, params, args.skip_unknown) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata) + + metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0) + + params.ftype = ftype + logger.info(f"Writing {outfile}, format {ftype}") + + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + logger.info(f"Wrote {outfile}") + + +if __name__ == '__main__': + main() diff --git a/llama.cpp/examples/debug/CMakeLists.txt b/llama.cpp/examples/debug/CMakeLists.txt new file mode 100644 index 0000000..3459307 --- /dev/null +++ b/llama.cpp/examples/debug/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-debug) +add_executable(${TARGET} debug.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/debug/README.md b/llama.cpp/examples/debug/README.md new file mode 100644 index 0000000..28e00c9 --- /dev/null +++ b/llama.cpp/examples/debug/README.md @@ -0,0 +1,54 @@ +# 
llama.cpp/examples/debug
+
+This is a utility intended to help debug a model by registering a callback that
+logs GGML operations and tensor data. It can also store the generated logits or
+embeddings, as well as the prompt and token ids, for comparison with the original
+model.
+
+### Usage
+
+```shell
+llama-debug \
+    --hf-repo ggml-org/models \
+    --hf-file phi-2/ggml-model-q4_0.gguf \
+    --model phi-2-q4_0.gguf \
+    --prompt hello \
+    --save-logits \
+    --verbose
+```
+The tensor data is logged at debug level and therefore requires the `--verbose`
+flag. The reason for this is that, while useful, a model with many layers can
+produce a lot of output. You can filter the tensor names using the
+`--tensor-filter` option.
+
+A recommended approach is to first run without `--verbose` and see if the
+generated logits/embeddings are close to the original model's. If they are not,
+then it might be necessary to inspect the model tensor by tensor, and in that
+case it is useful to enable the `--verbose` flag along with `--tensor-filter`
+to focus on specific tensors.
+
+### Options
+This example supports all standard `llama.cpp` options and also accepts the
+following options:
+```console
+$ llama-debug --help
+...
+
+----- example-specific params -----
+
+--save-logits            save final logits to files for verification (default: false)
+--logits-output-dir PATH directory for saving logits output files (default: data)
+--tensor-filter REGEX    filter tensor names for debug output (regex pattern, can be specified multiple times)
+```
+
+### Output Files
+
+When `--save-logits` is enabled, the following files are created in the output
+directory:
+
+* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
+* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
+* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
+* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
+
+These files can be compared against the original model's output to verify the
+converted model.
diff --git a/llama.cpp/examples/debug/debug.cpp b/llama.cpp/examples/debug/debug.cpp
new file mode 100644
index 0000000..88947ac
--- /dev/null
+++ b/llama.cpp/examples/debug/debug.cpp
@@ -0,0 +1,253 @@
+#include "debug.h"
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cstdlib>
+#include <string>
+#include <vector>
+#include <filesystem>
+#include <fstream>
+#include <regex>
+
+static void print_usage(int /*argc*/, char ** argv) {
+    const std::string usage_template = R"(
+        example usage:
+
+        Print tensors:
+
+            {prog} -m model.gguf -p "Hello my name is" --verbose
+
+        The tensors to be printed can be filtered with --tensor-filter option.
+
+        Save logits/embeddings:
+
+            {prog} -m model.gguf -p "Hello my name is" --save-logits
+
+        Add --embedding to save embeddings)" "\n";
+
+    // Fix the source code indentation above that is introduced by the raw string literal.
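+    // (Editor note) std::regex("\\n {8}") matches a newline followed by eight
+    // spaces, collapsing the raw string's base indentation to column zero.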
+ std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n"); + usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]); + LOG("%s\n", usage.c_str()); +} + +static bool has_pooling(llama_context * ctx) { + switch (llama_pooling_type(ctx)) { + case LLAMA_POOLING_TYPE_NONE: + case LLAMA_POOLING_TYPE_UNSPECIFIED: + return false; + default: + return true; + } +} + +struct output_data { + float * data_ptr = nullptr; + int data_size = 0; + std::string type_suffix; + std::vector<float> embd_norm; + std::string prompt; + std::vector<llama_token> tokens; + + output_data(llama_context * ctx, const llama_model * model, const common_params & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + tokens = common_tokenize(ctx, params.prompt, add_bos); + prompt = params.prompt; + + if (params.embedding) { + const int n_embd = llama_model_n_embd_out(model); + const bool pooling = has_pooling(ctx); + const int n_embd_count = pooling ? 1 : tokens.size(); + const int n_floats = n_embd * n_embd_count; + + float * embd_raw = pooling ? llama_get_embeddings_seq(ctx, 0) : llama_get_embeddings(ctx); + if (embd_raw == nullptr) { + throw std::runtime_error("failed to get embeddings from the model"); + } + + LOG_DBG("pooling_enabled: %s\n", pooling ? "true" : "false"); + LOG_DBG("n_embd: %d\n", n_embd); + LOG_DBG("n_floats: %d\n", n_floats); + LOG_DBG("n_embd_count: %d\n", n_embd_count); + + data_ptr = embd_raw; + data_size = n_floats; + type_suffix = "-embeddings"; + + if (params.embd_normalize >= 0) { + embd_norm.resize(n_floats); + for (int i = 0; i < n_embd_count; i++) { + common_embd_normalize(embd_raw+i*n_embd, embd_norm.data()+i*n_embd, n_embd, params.embd_normalize); + } + data_ptr = embd_norm.data(); + } + } else { + const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1); + const int n_logits = llama_vocab_n_tokens(vocab); + + data_ptr = const_cast<float*>(logits); + data_size = n_logits; + type_suffix = ""; + } + } +}; + +static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) { + std::filesystem::create_directory(output_dir); + auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix); + + // Save logits/embeddings to binary file. + { + std::filesystem::path filepath{base_path.string() + ".bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open binary output file: " + filepath.string()); + } + file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float)); + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save logits/embeddings to text file. + { + std::filesystem::path filepath{base_path.string() + ".txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open text output file: " + filepath.string()); + } + for (int i = 0; i < output.data_size; i++) { + file << i << ": " << output.data_ptr[i] << '\n'; + } + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save prompt and tokens to text file. 
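+    // (Editor note) the "-prompt.txt" file written below holds three lines: the
+    // prompt text, the token count, and the comma-separated token ids.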
+ { + std::filesystem::path filepath{base_path.string() + "-prompt.txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open prompt output file: " + filepath.string()); + } + + file << "prompt: " << output.prompt << '\n'; + file << "n_tokens: " << output.tokens.size() << '\n'; + + file << "token ids: "; + for (size_t i = 0; i < output.tokens.size(); i++) { + file << output.tokens[i]; + if (i + 1 < output.tokens.size()) { + file << ", "; + } + } + file << '\n'; + LOG("Prompt saved to %s\n", filepath.c_str()); + } + + // Save token ids to binary file. + { + std::filesystem::path filepath{base_path.string() + "-tokens.bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open tokens binary file: " + filepath.string()); + } + file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token)); + LOG("Tokens saved to %s\n", filepath.c_str()); + } + +} + +static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false"); + LOG("Input prompt: \"%s\"\n", prompt.c_str()); + LOG("Token ids (%zu):\n", tokens.size()); + + for (auto id : tokens) { + std::string piece(128, '\0'); + int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true); + if (n < 0) { + LOG_ERR("failed to convert token %d to piece\n", id); + continue; + } + piece.resize(n); + LOG("%s(%d) ", piece.c_str(), id); + } + LOG("\n"); +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + print_tokenized_prompt(ctx, tokens, params.prompt); + + if (params.save_logits) { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } + + return true; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + base_callback_data cb_data(params, params.tensor_filter); + + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + if (!run(ctx, params)) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/deprecation-warning/README.md 
b/llama.cpp/examples/deprecation-warning/README.md new file mode 100644 index 0000000..9a1b263 --- /dev/null +++ b/llama.cpp/examples/deprecation-warning/README.md @@ -0,0 +1,49 @@ +# Migration notice for binary filenames + +> [!IMPORTANT] +[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggml-org/llama.cpp/pull/7809) + +This migration was important, but it is a breaking change that may not always be immediately obvious to users. + +Please update all scripts and workflows to use the new binary names. + +| Old Filename | New Filename | +| ---- | ---- | +| main | llama-cli | +| server | llama-server | +| llama-bench | llama-bench | +| embedding | llama-embedding | +| quantize | llama-quantize | +| tokenize | llama-tokenize | +| export-lora | llama-export-lora | +| libllava.a | libllava.a | +| baby-llama | llama-baby-llama | +| batched | llama-batched | +| batched-bench | llama-batched-bench | +| benchmark-matmult | llama-benchmark-matmult | +| convert-llama2c-to-ggml | llama-convert-llama2c-to-ggml | +| eval-callback | llama-eval-callback | +| gbnf-validator | llama-gbnf-validator | +| gguf | llama-gguf | +| gguf-split | llama-gguf-split | +| gritlm | llama-gritlm | +| imatrix | llama-imatrix | +| infill | llama-infill | +| llava-cli | llama-llava-cli | +| lookahead | llama-lookahead | +| lookup | llama-lookup | +| lookup-create | llama-lookup-create | +| lookup-merge | llama-lookup-merge | +| lookup-stats | llama-lookup-stats | +| parallel | llama-parallel | +| passkey | llama-passkey | +| perplexity | llama-perplexity | +| q8dot | llama-q8dot | +| quantize-stats | llama-quantize-stats | +| retrieval | llama-retrieval | +| save-load-state | llama-save-load-state | +| simple | llama-simple | +| speculative | llama-speculative | +| vdot | llama-vdot | +| tests/test-c.o | tests/test-c.o | + diff --git a/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp b/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp new file mode 100644 index 0000000..11f5147 --- /dev/null +++ b/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp @@ -0,0 +1,35 @@ +// Warns users that this filename was deprecated, and provides a link for more information. 
+ +#include <cstdio> +#include <cstdlib> // EXIT_FAILURE +#include <string> +#include <unordered_map> + +// Main +int main(int argc, char** argv) { + std::string filename = "main"; + if (argc >= 1) { + filename = argv[0]; + } + + // Get only the program name from the full path + auto pos = filename.find_last_of("/\\"); + if (pos != std::string::npos) { + filename = filename.substr(pos+1); + } + + // Append "llama-" to the beginning of filename to get the replacement filename + auto replacement_filename = "llama-" + filename; + + // The exception is if the filename is "main", then our replacement filename is "llama-cli" + if (filename == "main") { + replacement_filename = "llama-cli"; + } + + fprintf(stdout, "\n"); + fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str()); + fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str()); + fprintf(stdout, " See https://github.com/ggml-org/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n"); + fprintf(stdout, "\n"); + + return EXIT_FAILURE; +} diff --git a/llama.cpp/examples/diffusion/CMakeLists.txt b/llama.cpp/examples/diffusion/CMakeLists.txt new file mode 100644 index 0000000..396549c --- /dev/null +++ b/llama.cpp/examples/diffusion/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-diffusion-cli) +add_executable(${TARGET} diffusion-cli.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/diffusion/README.md b/llama.cpp/examples/diffusion/README.md new file mode 100644 index 0000000..f71d241 --- /dev/null +++ b/llama.cpp/examples/diffusion/README.md @@ -0,0 +1,59 @@ +# Diffusion Text Generation + +This directory contains implementations for Diffusion LLMs (DLLMs). + +More Info: +- https://github.com/ggml-org/llama.cpp/pull/14644 +- https://github.com/ggml-org/llama.cpp/pull/14771 + +## Parameters +The diffusion CLI supports various parameters to control the generation process: + +### Core Diffusion Parameters +- `--diffusion-steps`: Number of diffusion steps (default: 256) +- `--diffusion-algorithm`: Algorithm for token selection + - `0`: ORIGIN - Tokens are generated in a purely random order, from https://arxiv.org/abs/2107.03006.
+ - `1`: ENTROPY_BASED - Entropy-based selection + - `2`: MARGIN_BASED - Margin-based selection + - `3`: RANDOM - Random selection + - `4`: CONFIDENCE_BASED - Confidence-based selection (default) + - More documentation is available at https://github.com/DreamLM/Dream +- `--diffusion-visual`: Enable live visualization during generation + +### Scheduling Parameters +Choose one of the following scheduling methods: + +**Timestep-based scheduling:** +- `--diffusion-eps`: Epsilon value for timestep scheduling (e.g., 0.001) + +**Block-based scheduling:** +- `--diffusion-block-length`: Block size for block-based scheduling (e.g., 32) + +### Sampling Parameters +- `--temp`: Temperature for sampling (0.0 = greedy/deterministic, higher = more random) +- `--top-k`: Top-k filtering for sampling +- `--top-p`: Top-p (nucleus) filtering for sampling +- `--seed`: Random seed for reproducibility + +### Model Parameters +- `-m`: Path to the GGUF model file +- `-p`: Input prompt text +- `-ub`: Maximum sequence length (ubatch size) +- `-c`: Context size +- `-b`: Batch size + +### Examples +#### Dream architecture: +``` +llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual +``` + +#### LLaDA architecture: +``` +llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual +``` + +#### RND1 architecture: +``` +llama-diffusion-cli -m RND1-Base-0910.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-algorithm 1 --diffusion-steps 256 --diffusion-visual --temp 0.5 --diffusion-eps 0.001 +``` diff --git a/llama.cpp/examples/diffusion/diffusion-cli.cpp b/llama.cpp/examples/diffusion/diffusion-cli.cpp new file mode 100644 index 0000000..d50f754 --- /dev/null +++ b/llama.cpp/examples/diffusion/diffusion-cli.cpp @@ -0,0 +1,694 @@ +#include "arg.h" +#include "chat.h" +#include "common.h" +#include "llama.h" +#include "log.h" + +#include <limits.h> + +#include <algorithm> +#include <cmath> +#include <cstring> +#include <limits> +#include <random> +#include <string> +#include <vector> + +enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 }; + +// Unified transfer scheduling methods +enum transfer_schedule { + TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining + BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens +}; + +typedef bool (*diffusion_step_callback_t)(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data); + +struct diffusion_params { + int32_t steps = 0; + float temperature = 0; + llama_token mask_token_id = LLAMA_TOKEN_NULL; + diffusion_step_callback_t step_callback = nullptr; + void * step_callback_user_data = nullptr; + int32_t seed = 0; + bool visual_mode = false; + bool shift_logits = false; // Shift logits by -1 after decode + + float top_p = 0.; + int32_t top_k = 0; + + diffusion_algorithm algorithm = CONFIDENCE_BASED; + transfer_schedule schedule = TIMESTEP_BASED; + + float cfg_scale = 0.; // CFG scale for classifier-free guidance + float eps = 0.; // Timestep scheduling + int32_t block_length = 0; // Block size (for block scheduling) + float alg_temp = 0; // algorithm temperature (0.0 = deterministic) + bool add_gumbel_noise = false; // Add Gumbel noise to the logits if temp > 0.0 + + int32_t max_length = 0; // Maximum sequence length +}; + 
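+// Overview of the two transfer schedules (see calculate_transfer_count() below):
+// - TIMESTEP_BASED (Dream-style): with t = 1 - step/steps * (1 - eps) and s = 1 - (step + 1)/steps * (1 - eps),
+//   a fraction p_transfer = 1 - s/t of the currently masked positions is unmasked at each step
+//   (all remaining positions on the final step).
+// - BLOCK_BASED (LLaDA-style): the masked positions of the current block are spread evenly
+//   over the block's steps via get_num_transfer_tokens().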
+struct callback_data { + diffusion_params * diff_params; + const llama_vocab * vocab; + int32_t n_input; +}; + +static float calculate_confidence(const llama_token_data_array & cur_p, + diffusion_algorithm algorithm, + std::mt19937 & rng) { + switch (algorithm) { + case CONFIDENCE_BASED: + return cur_p.data[cur_p.selected].p; // Selected token probability + + case ENTROPY_BASED: + { + float entropy = 0.0f; + const float epsilon = 1e-10f; + for (size_t i = 0; i < cur_p.size; i++) { + float prob = cur_p.data[i].p; + entropy += prob * logf(prob + epsilon); + } + return -entropy; // Higher entropy = lower confidence + } + + case MARGIN_BASED: + return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p; + + case RANDOM: + { + std::uniform_real_distribution<float> uniform(0.0f, 1.0f); + return uniform(rng); // Random confidence + } + + case ORIGIN: + return cur_p.data[cur_p.selected].p; + + default: + return 0.0f; + } +} + +// Unified transfer count calculation function +static int32_t calculate_transfer_count(int32_t step, + int32_t total_steps, + int32_t remaining_masked, + transfer_schedule schedule, + float eps, + const std::vector<int32_t> & num_transfer_tokens = {}) { + switch (schedule) { + case TIMESTEP_BASED: + { + float t = 1.0f - (float) step / total_steps * (1.0f - eps); + float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps); + float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f; + return (int32_t) (remaining_masked * p_transfer); + } + + case BLOCK_BASED: + if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) { + return num_transfer_tokens[step]; + } + return remaining_masked / (total_steps - step); // Fallback + + default: + return remaining_masked / (total_steps - step); + } +} + +static bool diffusion_step_callback(int32_t step, + int32_t total_steps, + const llama_token * tokens, + int32_t n_tokens, + void * user_data) { + (void) user_data; + + callback_data * data = static_cast<callback_data *>(user_data); + + auto print_progress_bar = [](int32_t step, int32_t total_steps) { + int progress_percent = (step * 100) / total_steps; + int progress_bars = (step * 50) / total_steps; + LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%", + step, + total_steps, + std::string(progress_bars, '=').c_str(), + std::string(50 - progress_bars, ' ').c_str(), + progress_percent); + }; + + if (data->diff_params->visual_mode) { + // Visual mode: clear + LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left + + print_progress_bar(step, total_steps); + + LOG_INF("\n"); + + std::string current_text = " "; + + for (int32_t i = data->n_input; i < n_tokens; i++) { + std::string token_str; + if (tokens[i] != llama_vocab_mask(data->vocab)) { + char piece[256]; + int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false); + if (n_chars > 0) { + piece[n_chars] = '\0'; + token_str = piece; + } + } else { + token_str = " "; + } + + current_text += token_str; + } + + LOG_INF("%s\n", current_text.c_str()); + } else { + print_progress_bar(step, total_steps); + } + + return true; +} + +static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) { + if (temperature == 0.0f) { + return; + } + + std::uniform_real_distribution<double> uniform(0.0, 1.0); + for (int32_t i = 0; i < n_vocab; i++) { + double noise = uniform(rng); + // Prevent log(0) + noise = std::max(noise, 1e-20); + double gumbel_noise = std::pow(-std::log(noise), temperature); + logits[i] = 
std::exp(logits[i]) / gumbel_noise; + } +} + +static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) { + std::vector<int32_t> num_transfer_tokens(steps); + + int32_t base = mask_count / steps; + int32_t remainder = mask_count % steps; + + for (int32_t i = 0; i < steps; i++) { + num_transfer_tokens[i] = base + (i < remainder ? 1 : 0); + } + + return num_transfer_tokens; +} + +static void diffusion_generate(llama_context * ctx, + const llama_token * input_tokens, + llama_token * output_tokens, + int32_t n_input, + const diffusion_params & params, + int32_t & n_generated) { + n_generated = 0; + if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) { + return; + } + + const llama_model * model = llama_get_model(ctx); + + // Initialize with input and pad with mask tokens + std::copy(input_tokens, input_tokens + n_input, output_tokens); + std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id); + + std::mt19937 rng(params.seed); + + llama_set_causal_attn(ctx, false); + + int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model)); + + std::vector<llama_token_data> candidates(n_vocab); + std::vector<llama_token_data> conf_candidates; + conf_candidates.reserve(params.max_length); + std::vector<int32_t> mask_positions; + mask_positions.reserve(params.max_length); + + // Setup sampler chain + struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params()); + if (params.top_k > 0) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k)); + } + if (params.top_p < 1.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1)); + } + if (params.temperature > 0.0f) { + llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature)); + } + llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed)); + + struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed); + + llama_batch batch = llama_batch_init(params.max_length, 0, 1); + batch.n_tokens = params.max_length; + + // Pre-allocate buffers for CFG if needed + int32_t logits_size = n_vocab * params.max_length; + std::vector<float> cond_logits_buffer; + std::vector<llama_token> un_x_buffer; + if (params.cfg_scale > 0.0f) { + cond_logits_buffer.resize(logits_size); + un_x_buffer.resize(params.max_length); + } + + // For block-based processing + std::vector<int32_t> num_transfer_tokens; + int32_t num_blocks = 1; + int32_t steps_per_block = params.steps; + + if (params.schedule == BLOCK_BASED) { + GGML_ASSERT(params.max_length % params.block_length == 0); + num_blocks = params.max_length / params.block_length; + GGML_ASSERT(params.steps % num_blocks == 0); + steps_per_block = params.steps / num_blocks; + } + + std::vector<float> confidence(params.max_length); + + int64_t total_sampling_time = 0; + int64_t total_time = 0; + int64_t time_start = ggml_time_us(); + + for (int block_num = 0; block_num < num_blocks; block_num++) { + int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0; + int32_t block_end = (params.schedule == BLOCK_BASED) ? 
+ std::min(n_input + (block_num + 1) * params.block_length, params.max_length) : + params.max_length; + + // Count masked tokens in current block for block-based processing + if (params.schedule == BLOCK_BASED) { + int32_t block_mask_count = 0; + for (int i = block_start; i < block_end; i++) { + if (output_tokens[i] == params.mask_token_id) { + block_mask_count++; + } + } + num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block); + } + + for (int32_t step = 0; step < steps_per_block; step++) { + int32_t global_step = block_num * steps_per_block + step; + + if (params.step_callback) { + if (!params.step_callback( + global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) { + break; + } + } + + // Setup batch + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = output_tokens[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = 0; + batch.logits[i] = 1; + } + + float * logits = nullptr; + + if (params.cfg_scale > 0.0f) { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate conditional"); + break; + } + float * cond_logits_ptr = llama_get_logits(ctx); + std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float)); + + // Unconditional generation (mask input) + std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin()); + for (int32_t i = 0; i < n_input; i++) { + un_x_buffer[i] = params.mask_token_id; + } + + for (int32_t i = 0; i < params.max_length; i++) { + batch.token[i] = un_x_buffer[i]; + } + ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("Failed to generate unconditional"); + break; + } + float * uncond_logits = llama_get_logits(ctx); + + // Apply CFG + for (int32_t i = 0; i < logits_size; i++) { + cond_logits_buffer[i] = + uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]); + } + logits = cond_logits_buffer.data(); + } else { + int ret = llama_decode(ctx, batch); + if (ret != 0) { + LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret); + break; + } + logits = llama_get_logits(ctx); + } + + if (!logits) { + LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step); + break; + } + + auto get_logits_for_pos = [&](int32_t pos) -> const float * { + if (params.shift_logits) { + return pos == 0 ? 
logits : logits + (pos - 1) * n_vocab; + } + return logits + (pos) *n_vocab; + }; + + int64_t time_start_sampling = ggml_time_us(); + + mask_positions.clear(); + for (int32_t i = 0; i < params.max_length; i++) { + if (output_tokens[i] == params.mask_token_id) { + // For block-based, only consider current block + if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) { + mask_positions.push_back(i); + } + } + } + + if (mask_positions.empty()) { + break; + } + + if (params.add_gumbel_noise && params.temperature > 0.0f) { + add_gumbel_noise(logits, n_vocab, params.temperature, rng); + } + + if (params.algorithm == ORIGIN) { + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + float p_transfer = (float) transfer_count / mask_positions.size(); + + for (int32_t pos : mask_positions) { + if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) { + const float * pos_logits = get_logits_for_pos(pos); + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].id = token_id; + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + } + + llama_token_data_array cur_p = { + candidates.data(), + (size_t) n_vocab, + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + output_tokens[pos] = cur_p.data[cur_p.selected].id; + } + } + } else { + std::vector<std::pair<float, int32_t>> confidences; + std::vector<llama_token> sampled_tokens(mask_positions.size()); + + for (size_t i = 0; i < mask_positions.size(); i++) { + int32_t pos = mask_positions[i]; + const float * pos_logits = get_logits_for_pos(pos); + + for (int32_t token_id = 0; token_id < n_vocab; token_id++) { + candidates[token_id].logit = pos_logits[token_id]; + candidates[token_id].p = 0.0f; + candidates[token_id].id = token_id; + } + + llama_token_data_array cur_p = { + candidates.data(), + candidates.size(), + -1, + false, + }; + + llama_sampler_apply(sampler, &cur_p); + llama_token sampled_token = cur_p.data[cur_p.selected].id; + + float conf = calculate_confidence(cur_p, params.algorithm, rng); + + sampled_tokens[i] = sampled_token; + confidences.emplace_back(conf, i); + } + + int32_t transfer_count = calculate_transfer_count( + step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens); + + if (transfer_count > 0) { + if (params.alg_temp == 0.0f) { + std::partial_sort(confidences.begin(), + confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()), + confidences.end(), + [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) { + if (a.first != b.first) { + return a.first > b.first; + } + return a.second < b.second; + }); + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + int32_t mask_idx = confidences[i].second; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + } + } else { + conf_candidates.clear(); + for (size_t i = 0; i < confidences.size(); i++) { + float conf_logit = confidences[i].first / params.alg_temp; + conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f }); + } + + llama_token_data_array conf_array = { + conf_candidates.data(), + conf_candidates.size(), + -1, + false, + }; + + for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) { + llama_sampler_apply(dist_sampler, &conf_array); + int32_t selected_idx = 
conf_array.selected; + int32_t mask_idx = selected_idx; + int32_t pos = mask_positions[mask_idx]; + output_tokens[pos] = sampled_tokens[mask_idx]; + + conf_candidates[selected_idx].p = 0.0f; + conf_array.selected = -1; + } + } + } + } + + int64_t time_end_sampling = ggml_time_us(); + total_sampling_time += time_end_sampling - time_start_sampling; + } + } + + int64_t time_end = ggml_time_us(); + total_time += time_end - time_start; + + LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n", + total_time / 1000.0, + total_time / 1000.0 / params.steps, + total_sampling_time / 1000.0 / params.steps); + + llama_batch_free(batch); + llama_sampler_free(sampler); + llama_sampler_free(dist_sampler); + + n_generated = params.max_length; +} + +static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) { + if (!use_chat_template) { + return prompt; + } + + auto chat_templates = common_chat_templates_init(model, ""); + common_chat_templates_inputs inputs; + common_chat_msg system_msg; + + if (!system_prompt.empty()) { + system_msg.role = "system"; + system_msg.content = system_prompt; + inputs.messages.push_back(system_msg); + } + + common_chat_msg user_msg; + user_msg.role = "user"; + user_msg.content = prompt; + + inputs.messages.push_back(user_msg); + inputs.add_generation_prompt = true; + + auto result = common_chat_templates_apply(chat_templates.get(), inputs); + + return result.prompt; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) { + return 1; + } + + common_init(); + llama_backend_init(); + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = params.n_gpu_layers; + model_params.devices = params.devices.data(); + model_params.use_mmap = params.use_mmap; + model_params.use_direct_io = params.use_direct_io; + model_params.use_mlock = params.use_mlock; + model_params.check_tensors = params.check_tensors; + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + if (!model) { + LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str()); + return 1; + } + + if (!llama_model_is_diffusion(model)) { + LOG_ERR("error: unsupported model for diffusion"); + llama_model_free(model); + return 1; + } + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = params.n_ctx; + ctx_params.n_batch = params.n_batch; + ctx_params.n_ubatch = params.n_ubatch; + ctx_params.flash_attn_type = params.flash_attn_type; + ctx_params.no_perf = params.no_perf; + ctx_params.type_k = params.cache_type_k; + ctx_params.type_v = params.cache_type_v; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + LOG_ERR("error: failed to create context\n"); + llama_model_free(model); + return 1; + } + + llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads); + + const llama_vocab * vocab = llama_model_get_vocab(model); + + std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model); + + std::vector<llama_token> input_tokens = common_tokenize(vocab, + formatted_prompt, + /*add special tokens*/ true, + /*parse special*/ true); + + int n_input = input_tokens.size(); + + if (n_input >= params.n_ctx) { + LOG_ERR("error: input too long (%d tokens), max context is 
%d\n", n_input, params.n_ctx); + llama_free(ctx); + llama_model_free(model); + return 1; + } + + llama_token mask_token_id = llama_vocab_mask(vocab); + + GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL); + + bool visual_mode = params.diffusion.visual_mode; + + int32_t n_generated = 0; + std::vector<llama_token> output_tokens(params.n_ubatch); + + struct diffusion_params diff_params; + + char shift_logits_str[8]; + if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) { + diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0); + } else { + diff_params.shift_logits = true; + } + + //Use either eps or block length, but not both + GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0)); + + if (params.diffusion.eps) { + diff_params.schedule = TIMESTEP_BASED; + diff_params.eps = params.diffusion.eps; + } else if (params.diffusion.block_length) { + diff_params.schedule = BLOCK_BASED; + diff_params.block_length = params.diffusion.block_length; + } + + diff_params.mask_token_id = mask_token_id; + diff_params.seed = params.sampling.seed; + diff_params.temperature = params.sampling.temp; + diff_params.steps = params.diffusion.steps; + diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm); + diff_params.max_length = params.n_ubatch; + diff_params.top_p = params.sampling.top_p; + diff_params.top_k = params.sampling.top_k; + diff_params.visual_mode = params.diffusion.visual_mode; + diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise; + + diff_params.step_callback = diffusion_step_callback; + callback_data cb_data = { &diff_params, vocab, n_input }; + diff_params.step_callback_user_data = &cb_data; + + const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" }; + const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" }; + const char * alg_name = + (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN"; + const char * sched_name = + (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? 
sched_names[diff_params.schedule] : "UNKNOWN"; + + LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps); + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name); + LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature); + if (diff_params.schedule == TIMESTEP_BASED) { + LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp); + } + if (diff_params.schedule == BLOCK_BASED) { + LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length); + LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale); + } + + diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated); + + if (n_generated > 0) { + if (visual_mode) { + // clear screen and move cursor to top-left + LOG_INF("\033[2J\033[H"); + } + + output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input); + std::string output_data = common_detokenize(vocab, output_tokens, false); + LOG_INF("\n%s\n", output_data.c_str()); + } else { + LOG_INF("Error: diffusion generation failed\n"); + } + + llama_free(ctx); + llama_model_free(model); + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/embedding/CMakeLists.txt b/llama.cpp/examples/embedding/CMakeLists.txt new file mode 100644 index 0000000..8090403 --- /dev/null +++ b/llama.cpp/examples/embedding/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-embedding) +add_executable(${TARGET} embedding.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/embedding/README.md b/llama.cpp/examples/embedding/README.md new file mode 100644 index 0000000..1684f36 --- /dev/null +++ b/llama.cpp/examples/embedding/README.md @@ -0,0 +1,61 @@ +# llama.cpp/examples/embedding + +This example demonstrates how to generate a high-dimensional embedding vector for a given text with llama.cpp. + +## Quick Start + +To get started right away, run the following command, making sure to use the correct path for the model you have: + +### Unix-based systems (Linux, macOS, etc.): + +```bash +./llama-embedding -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>/dev/null +``` + +### Windows: + +```powershell +llama-embedding.exe -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>$null +``` + +The above command will output space-separated float values.
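As a quick sanity check, that stream of floats can be piped into a small program. Below is a minimal C++ sketch (not part of this example; the `check-embd` name is hypothetical) that reads the vector from stdin and reports its dimension and Euclidean norm, which should be close to 1.0 under the default `--embd-normalize 2`:

```cpp
// check-embd.cpp - hypothetical helper, not shipped with llama.cpp.
// Reads space-separated floats from stdin and prints the vector's
// dimension and L2 norm.
#include <cmath>
#include <iostream>
#include <vector>

int main() {
    std::vector<float> embd;
    float x;
    while (std::cin >> x) {
        embd.push_back(x);
    }

    double sum_sq = 0.0;
    for (float v : embd) {
        sum_sq += (double) v * v;
    }

    std::cout << "dimensions: " << embd.size() << "\n";
    std::cout << "L2 norm:    " << std::sqrt(sum_sq) << "\n";
    return 0;
}
```

For example: `./llama-embedding -m ./path/to/model --pooling mean --log-disable -p "Hello World!" 2>/dev/null | ./check-embd`.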
+ +## extra parameters +### --embd-normalize $integer$ +| $integer$ | description | formula | +|-----------|---------------------|---------| +| $-1$ | none | +| $0$ | max absolute int16 | $\Large{{32760 * x_i} \over\max \lvert x_i\rvert}$ +| $1$ | taxicab | $\Large{x_i \over\sum \lvert x_i\rvert}$ +| $2$ | euclidean (default) | $\Large{x_i \over\sqrt{\sum x_i^2}}$ +| $>2$ | p-norm | $\Large{x_i \over\sqrt[p]{\sum \lvert x_i\rvert^p}}$ + +### --embd-output-format $'string'$ +| $'string'$ | description | | +|------------|------------------------------|--| +| '' | same as before | (default) +| 'array' | single embeddings | $[[x_1,...,x_n]]$ +| | multiple embeddings | $[[x_1,...,x_n],[x_1,...,x_n],...,[x_1,...,x_n]]$ +| 'json' | openai style | +| 'json+' | add cosine similarity matrix | +| 'raw' | plain text output | + +### --embd-separator $"string"$ +| $"string"$ | | +|--------------|-| +| "\n" | (default) +| "<#embSep#>" | for example +| "<#sep#>" | other example + +## examples +### Unix-based systems (Linux, macOS, etc.): + +```bash +./llama-embedding -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` + +### Windows: + +```powershell +llama-embedding.exe -p 'Castle<#sep#>Stronghold<#sep#>Dog<#sep#>Cat' --pooling mean --embd-separator '<#sep#>' --embd-normalize 2 --embd-output-format '' -m './path/to/model.gguf' --n-gpu-layers 99 --log-disable 2>/dev/null +``` diff --git a/llama.cpp/examples/embedding/embedding.cpp b/llama.cpp/examples/embedding/embedding.cpp new file mode 100644 index 0000000..d8eaaa2 --- /dev/null +++ b/llama.cpp/examples/embedding/embedding.cpp @@ -0,0 +1,411 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include <ctime> +#include <algorithm> + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") { + std::vector<std::string> lines; + size_t start = 0; + size_t end = s.find(separator); + + while (end != std::string::npos) { + lines.push_back(s.substr(start, end - start)); + start = end + separator.length(); + end = s.find(separator, start); + } + + lines.push_back(s.substr(start)); // Add the last part + + return lines; +} + +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } +} + +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + + // clear previous kv_cache values (irrelevant for embeddings) + llama_memory_clear(llama_get_memory(ctx), true); + + // run model + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); + } + + for (int i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + + const float * embd = nullptr; + int embd_pos = 0; + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings + embd = llama_get_embeddings_ith(ctx, i); + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to 
get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd_pos = batch.seq_id[i][0]; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + } + + float * out = output + embd_pos * n_embd_out; + common_embd_normalize(embd, out, n_embd_out, embd_norm); + } +} + +// plain, pipe-friendly output: one embedding per line +static void print_raw_embeddings(const float * emb, + int n_embd_count, + int n_embd, + const llama_model * model, + enum llama_pooling_type pooling_type, + int embd_normalize) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + const bool is_rank = (pooling_type == LLAMA_POOLING_TYPE_RANK); + const int cols = is_rank ? std::min<int>(n_embd, (int) n_cls_out) : n_embd; + + for (int j = 0; j < n_embd_count; ++j) { + for (int i = 0; i < cols; ++i) { + if (embd_normalize == 0) { + LOG("%1.0f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } else { + LOG("%1.7f%s", emb[j * n_embd + i], (i + 1 < cols ? " " : "")); + } + } + LOG("\n"); + } +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) { + return 1; + } + + common_init(); + + params.embedding = true; + + // get max number of sequences per batch + const int n_seq_max = llama_max_parallel_sequences(); + + // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the + // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache + // in order to support any number of prompts + if (params.n_parallel == 1) { + LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); + params.kv_unified = true; + params.n_parallel = n_seq_max; + } + + // utilize the full context + if (params.n_batch < params.n_ctx) { + LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx); + params.n_batch = params.n_ctx; + } + + // for non-causal models, batch size must be equal to ubatch size + if (params.attention_type != LLAMA_ATTENTION_TYPE_CAUSAL) { + params.n_ubatch = params.n_batch; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + + if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + LOG_ERR("%s: computing embeddings in encoder-decoder models is not supported\n", __func__); + return 1; + } + + if (n_ctx > n_ctx_train) { + LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + // split the prompt into lines + std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep); + + // max batch size + const uint64_t n_batch = params.n_batch; + + // get added sep and eos token, if any + const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? 
llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : ""; + const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : ""; + const char * rerank_prompt = llama_model_chat_template(model, "rerank"); + + // tokenize the prompts and trim + std::vector<std::vector<int32_t>> inputs; + for (const auto & prompt : prompts) { + std::vector<llama_token> inp; + + // split classification pairs and insert expected separator tokens + if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { + std::vector<std::string> pairs = split_lines(prompt, params.cls_sep); + if (rerank_prompt != nullptr) { + const std::string query = pairs[0]; + const std::string doc = pairs[1]; + std::string final_prompt = rerank_prompt; + string_replace_all(final_prompt, "{query}" , query); + string_replace_all(final_prompt, "{document}", doc ); + inp = common_tokenize(vocab, final_prompt, true, true); + } else { + std::string final_prompt; + for (size_t i = 0; i < pairs.size(); i++) { + final_prompt += pairs[i]; + if (i != pairs.size() - 1) { + if (!added_eos_token.empty()) { + final_prompt += added_eos_token; + } + if (!added_sep_token.empty()) { + final_prompt += added_sep_token; + } + } + } + inp = common_tokenize(ctx, final_prompt, true, true); + } + } else { + inp = common_tokenize(ctx, prompt, true, true); + } + if (inp.size() > n_batch) { + LOG_ERR("%s: number of tokens in input line (%lld) exceeds batch size (%lld), increase batch size and re-run\n", + __func__, (long long int) inp.size(), (long long int) n_batch); + return 1; + } + inputs.push_back(inp); + } + + // check if the last token is SEP/EOS + // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' + for (auto & inp : inputs) { + if (inp.empty() || (inp.back() != llama_vocab_sep(vocab) && inp.back() != llama_vocab_eos(vocab))) { + LOG_WRN("%s: last token in the prompt is not SEP or EOS\n", __func__); + LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); + } + } + + // tokenization stats + if (params.verbose_prompt) { + for (int i = 0; i < (int) inputs.size(); i++) { + LOG_INF("%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size()); + for (int j = 0; j < (int) inputs[i].size(); j++) { + LOG("%6d -> '%s'\n", inputs[i][j], common_token_to_piece(ctx, inputs[i][j]).c_str()); + } + LOG("\n\n"); + } + } + + // initialize batch + const int n_prompts = prompts.size(); + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + + // count number of embeddings + int n_embd_count = 0; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int k = 0; k < n_prompts; k++) { + n_embd_count += inputs[k].size(); + } + } else { + n_embd_count = n_prompts; + } + + // allocate output + const int n_embd_out = llama_model_n_embd_out(model); + std::vector<float> embeddings(n_embd_count * n_embd_out, 0); + float * emb = embeddings.data(); + + // break into batches + int e = 0; // number of embeddings already stored + int s = 0; // number of prompts in current batch + for (int k = 0; k < n_prompts; k++) { + // clamp to n_batch tokens + auto & inp = inputs[k]; + + const uint64_t n_toks = inp.size(); + + // encode if at capacity + if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) { + float * out = emb + e * n_embd_out; + batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); + e += 
pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; + s = 0; + common_batch_clear(batch); + } + + // add to batch + batch_add_seq(batch, inp, s); + s += 1; + } + + // final batch + float * out = emb + e * n_embd_out; + batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); + + if (params.embd_out.empty()) { + LOG("\n"); + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int j = 0; j < n_embd_count; j++) { + LOG("embedding %d: ", j); + for (int i = 0; i < std::min(3, n_embd_out); i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd_out + i]); + } else { + LOG("%9.6f ", emb[j * n_embd_out + i]); + } + } + LOG(" ... "); + for (int i = n_embd_out - 3; i < n_embd_out; i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd_out + i]); + } else { + LOG("%9.6f ", emb[j * n_embd_out + i]); + } + } + LOG("\n"); + } + } else if (pooling_type == LLAMA_POOLING_TYPE_RANK) { + const uint32_t n_cls_out = llama_model_n_cls_out(model); + std::vector<std::string> cls_out_labels; + + for (uint32_t i = 0; i < n_cls_out; i++) { + const char * label = llama_model_cls_label(model, i); + const std::string label_i(label == nullptr ? "" : label); + cls_out_labels.emplace_back(label_i.empty() ? std::to_string(i) : label_i); + } + + for (int j = 0; j < n_embd_count; j++) { + for (uint32_t i = 0; i < n_cls_out; i++) { + // NOTE: if you change this log - update the tests in ci/run.sh + if (n_cls_out == 1) { + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]); + } else { + LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str()); + } + } + } + } else { + // print the first part of the embeddings or for a single prompt, the full embedding + for (int j = 0; j < n_prompts; j++) { + LOG("embedding %d: ", j); + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) { + if (params.embd_normalize == 0) { + LOG("%6.0f ", emb[j * n_embd_out + i]); + } else { + LOG("%9.6f ", emb[j * n_embd_out + i]); + } + } + LOG("\n"); + } + + // print cosine similarity matrix + if (n_prompts > 1) { + LOG("\n"); + LOG("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + LOG("%6.6s ", prompts[i].c_str()); + } + LOG("\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); + LOG("%6.2f ", sim); + } + LOG("%1.10s", prompts[i].c_str()); + LOG("\n"); + } + } + } + } + + if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") { + const bool notArray = params.embd_out != "array"; + + LOG(notArray ? "{\n \"object\": \"list\",\n \"data\": [\n" : "["); + for (int j = 0;;) { // at least one iteration (one prompt) + if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); + LOG("["); + for (int i = 0;;) { // at least one iteration (n_embd > 0) + LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]); + i++; + if (i < n_embd_out) LOG(","); else break; + } + LOG(notArray ? "]\n }" : "]"); + j++; + if (j < n_embd_count) LOG(notArray ? ",\n" : ","); else break; + } + LOG(notArray ? 
"\n ]" : "]\n"); + + if (params.embd_out == "json+" && n_prompts > 1) { + LOG(",\n \"cosineSimilarity\": [\n"); + for (int i = 0;;) { // at least two iteration (n_embd_count > 1) + LOG(" ["); + for (int j = 0;;) { // at least two iteration (n_embd_count > 1) + float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); + LOG("%6.2f", sim); + j++; + if (j < n_embd_count) LOG(", "); else break; + } + LOG(" ]"); + i++; + if (i < n_embd_count) LOG(",\n"); else break; + } + LOG("\n ]"); + } + + if (notArray) LOG("\n}\n"); + } else if (params.embd_out == "raw") { + print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize); + } + + LOG("\n"); + llama_perf_context_print(ctx); + + // clean up + llama_batch_free(batch); + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/eval-callback/CMakeLists.txt b/llama.cpp/examples/eval-callback/CMakeLists.txt new file mode 100644 index 0000000..6439690 --- /dev/null +++ b/llama.cpp/examples/eval-callback/CMakeLists.txt @@ -0,0 +1,26 @@ +set(TARGET llama-eval-callback) +add_executable(${TARGET} eval-callback.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +if(LLAMA_BUILD_TESTS) + if(NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x") + set(MODEL_NAME "tinyllamas/stories15M-q4_0.gguf") + set(MODEL_HASH "SHA256=66967fbece6dbe97886593fdbb73589584927e29119ec31f08090732d1861739") + else() + set(MODEL_NAME "tinyllamas/stories15M-be.Q4_0.gguf") + set(MODEL_HASH "SHA256=9aec857937849d976f30397e97eb1cabb53eb9dcb1ce4611ba8247fb5f44c65d") + endif() + set(MODEL_DEST "${CMAKE_BINARY_DIR}/${MODEL_NAME}") + set(TEST_TARGET test-eval-callback) + add_test(NAME ${TEST_TARGET}-download-model COMMAND ${CMAKE_COMMAND} + -DDEST=${MODEL_DEST} + -DNAME=${MODEL_NAME} + -DHASH=${MODEL_HASH} + -P ${CMAKE_SOURCE_DIR}/cmake/download-models.cmake + ) + set_tests_properties(${TEST_TARGET}-download-model PROPERTIES FIXTURES_SETUP ${TEST_TARGET}-download-model) + add_test(NAME ${TEST_TARGET} COMMAND llama-eval-callback -m "${MODEL_DEST}" --prompt hello --seed 42 -ngl 0) + set_tests_properties(${TEST_TARGET} PROPERTIES FIXTURES_REQUIRED ${TEST_TARGET}-download-model) +endif() diff --git a/llama.cpp/examples/eval-callback/README.md b/llama.cpp/examples/eval-callback/README.md new file mode 100644 index 0000000..63a57ad --- /dev/null +++ b/llama.cpp/examples/eval-callback/README.md @@ -0,0 +1,95 @@ +# llama.cpp/examples/eval-callback + +A simple example which demonstrates how to use callback during the inference. +It simply prints to the console all operations and tensor data. + +Usage: + +```shell +llama-eval-callback \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --seed 42 \ + -ngl 33 +``` + +Will print: + +```shell +llm_load_tensors: offloaded 33/33 layers to GPU +... +llama_new_context_with_model: n_ctx = 512 +... 
+llama_new_context_with_model: CUDA0 compute buffer size = 105.00 MiB +llama_new_context_with_model: CUDA_Host compute buffer size = 6.01 MiB +llama_new_context_with_model: graph nodes = 1225 +llama_new_context_with_model: graph splits = 2 +ggml_debug: inp_embd = (f32) GET_ROWS(token_embd.weight{2560, 51200, 1, 1}, inp_tokens{1, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.0181, 0.0272, 0.0272, ...], + ], + ] +ggml_debug: norm-0 = (f32) NORM(CUDA0#inp_embd#0{2560, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -0.6989, 1.0636, 1.0636, ...], + ], + ] +ggml_debug: norm_w-0 = (f32) MUL(norm-0{2560, 1, 1, 1}, blk.0.attn_norm.weight{2560, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.1800, 0.2817, 0.2632, ...], + ], + ] +ggml_debug: attn_norm-0 = (f32) ADD(norm_w-0{2560, 1, 1, 1}, blk.0.attn_norm.bias{2560, 1, 1, 1}}) = {2560, 1, 1, 1} + [ + [ + [ -0.1863, 0.2970, 0.2604, ...], + ], + ] +ggml_debug: wqkv-0 = (f32) MUL_MAT(blk.0.attn_qkv.weight{2560, 7680, 1, 1}, attn_norm-0{2560, 1, 1, 1}}) = {7680, 1, 1, 1} + [ + [ + [ -1.1238, 1.2876, -1.8086, ...], + ], + ] +ggml_debug: bqkv-0 = (f32) ADD(wqkv-0{7680, 1, 1, 1}, blk.0.attn_qkv.bias{7680, 1, 1, 1}}) = {7680, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: bqkv-0 (view) = (f32) VIEW(bqkv-0{7680, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: Qcur-0 = (f32) CONT(bqkv-0 (view){2560, 1, 1, 1}, }) = {2560, 1, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + ], + ] +ggml_debug: Qcur-0 (reshaped) = (f32) RESHAPE(Qcur-0{2560, 1, 1, 1}, }) = {80, 32, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + [ -0.3608, 0.5076, -1.8866, ...], + [ 1.7643, 0.0273, -2.1065, ...], + ... + ], + ] +ggml_debug: Qcur-0 = (f32) ROPE(Qcur-0 (reshaped){80, 32, 1, 1}, CUDA0#inp_pos#0{1, 1, 1, 1}}) = {80, 32, 1, 1} + [ + [ + [ -1.1135, 1.4604, -1.9226, ...], + [ -0.3608, 0.5076, -1.8866, ...], + [ 1.7643, 0.0273, -2.1065, ...], + ... 
+ ], + ] +``` diff --git a/llama.cpp/examples/eval-callback/eval-callback.cpp b/llama.cpp/examples/eval-callback/eval-callback.cpp new file mode 100644 index 0000000..bd58734 --- /dev/null +++ b/llama.cpp/examples/eval-callback/eval-callback.cpp @@ -0,0 +1,80 @@ +#include "arg.h" +#include "common.h" +#include "debug.h" +#include "log.h" +#include "llama.h" +#include "llama-cpp.h" +#include <string> +#include <vector> + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + return true; +} + +int main(int argc, char ** argv) { + base_callback_data cb_data; + + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + // pass the callback to the backend scheduler + // it will be executed for each node during the graph computation + params.cb_eval = common_debug_cb_eval<false>; + params.cb_eval_user_data = &cb_data; + params.warmup = false; + + // init + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + bool OK = run(ctx, params); + if (!OK) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/gen-docs/CMakeLists.txt b/llama.cpp/examples/gen-docs/CMakeLists.txt new file mode 100644 index 0000000..25de0af --- /dev/null +++ b/llama.cpp/examples/gen-docs/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-gen-docs) +add_executable(${TARGET} gen-docs.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/gen-docs/gen-docs.cpp b/llama.cpp/examples/gen-docs/gen-docs.cpp new file mode 100644 index 0000000..0aa33e8 --- /dev/null +++ b/llama.cpp/examples/gen-docs/gen-docs.cpp @@ -0,0 +1,142 @@ +#include "arg.h" +#include "common.h" + +#include <fstream> +#include <sstream> +#include <string> + +// Export usage message (-h) to markdown format +// Automatically update the markdown docs + +#define HELP_START_MARKER "<!-- HELP_START -->" +#define HELP_END_MARKER "<!-- HELP_END -->" +#define NOTE_MESSAGE "<!-- IMPORTANT: The list below is auto-generated by llama-gen-docs; do NOT modify it manually -->" + +struct md_file { + llama_example ex; + std::string fname; + std::string specific_section_header; +}; + +std::vector<md_file> md_files = { + {LLAMA_EXAMPLE_CLI, "tools/cli/README.md", "CLI-specific params"}, + {LLAMA_EXAMPLE_COMPLETION, "tools/completion/README.md", "Completion-specific params"}, + {LLAMA_EXAMPLE_SERVER, "tools/server/README.md", 
"Server-specific params"}, +}; + +static void write_table_header(std::ostringstream & ss) { + ss << "| Argument | Explanation |\n"; + ss << "| -------- | ----------- |\n"; +} + +static void write_table_entry(std::ostringstream & ss, const common_arg & opt) { + ss << "| `"; + // args + auto all_args = opt.get_args(); + for (const auto & arg : all_args) { + if (arg == all_args.front()) { + ss << arg; + if (all_args.size() > 1) ss << ", "; + } else { + ss << arg << (arg != all_args.back() ? ", " : ""); + } + } + // value hint + if (opt.value_hint) { + std::string md_value_hint(opt.value_hint); + string_replace_all(md_value_hint, "|", "\\|"); + ss << " " << md_value_hint; + } + if (opt.value_hint_2) { + std::string md_value_hint_2(opt.value_hint_2); + string_replace_all(md_value_hint_2, "|", "\\|"); + ss << " " << md_value_hint_2; + } + // help text + std::string md_help(opt.help); + md_help = string_strip(md_help); + string_replace_all(md_help, "\n", "<br/>"); + string_replace_all(md_help, "|", "\\|"); + ss << "` | " << md_help << " |\n"; +} + +static void write_table(std::ostringstream & ss, std::vector<common_arg *> & opts) { + write_table_header(ss); + for (const auto & opt : opts) { + write_table_entry(ss, *opt); + } +} + +static void write_help(std::ostringstream & ss, const md_file & md) { + common_params params; + auto ctx_arg = common_params_parser_init(params, md.ex); + + std::vector<common_arg *> common_options; + std::vector<common_arg *> sparam_options; + std::vector<common_arg *> specific_options; + for (auto & opt : ctx_arg.options) { + // in case multiple LLAMA_EXAMPLE_* are set, we prioritize the LLAMA_EXAMPLE_* matching current example + if (opt.is_sparam) { + sparam_options.push_back(&opt); + } else if (opt.in_example(ctx_arg.ex)) { + specific_options.push_back(&opt); + } else { + common_options.push_back(&opt); + } + } + + ss << HELP_START_MARKER << "\n\n"; + + ss << NOTE_MESSAGE << "\n\n"; + + ss << "### Common params\n\n"; + write_table(ss, common_options); + ss << "\n\n### Sampling params\n\n"; + write_table(ss, sparam_options); + ss << "\n\n### " << md.specific_section_header << "\n\n"; + write_table(ss, specific_options); + + ss << "\n" << HELP_END_MARKER; +} + +int main(int, char **) { + for (const auto & md : md_files) { + std::ifstream infile(md.fname); + if (!infile.is_open()) { + fprintf(stderr, "failed to open file '%s' for reading\n", md.fname.c_str()); + return 1; + } + + std::ostringstream ss; + ss << infile.rdbuf(); + infile.close(); + + std::string content = ss.str(); + + size_t help_start = content.find(HELP_START_MARKER); + size_t help_end = content.find(HELP_END_MARKER); + + if (help_start == std::string::npos || help_end == std::string::npos || help_end <= help_start) { + fprintf(stderr, "failed to find help markers in file '%s'\n", md.fname.c_str()); + return 1; + } + + std::ostringstream new_help_ss; + write_help(new_help_ss, md); + std::string new_help = new_help_ss.str(); + + content = content.substr(0, help_start) + new_help + content.substr(help_end + strlen(HELP_END_MARKER)); + + std::ofstream outfile(md.fname); + if (!outfile.is_open()) { + fprintf(stderr, "failed to open file '%s' for writing\n", md.fname.c_str()); + return 1; + } + outfile << content; + outfile.close(); + + printf("Updated help in '%s'\n", md.fname.c_str()); + } + + return 0; +} diff --git a/llama.cpp/examples/gguf-hash/CMakeLists.txt b/llama.cpp/examples/gguf-hash/CMakeLists.txt new file mode 100644 index 0000000..15c5c68 --- /dev/null +++ 
b/llama.cpp/examples/gguf-hash/CMakeLists.txt
@@ -0,0 +1,22 @@
+set(TARGET llama-gguf-hash)
+add_executable(${TARGET} gguf-hash.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+
+# clibs dependencies
+include_directories(deps/)
+
+add_library(xxhash OBJECT deps/xxhash/xxhash.c deps/xxhash/xxhash.h)
+target_link_libraries(${TARGET} PRIVATE xxhash)
+
+add_library(sha1 OBJECT deps/sha1/sha1.c deps/sha1/sha1.h)
+target_link_libraries(${TARGET} PRIVATE sha1)
+if (NOT MSVC)
+    # disable warnings in 3rd party code
+    target_compile_options(sha1 PRIVATE -w)
+endif()
+
+add_library(sha256 OBJECT deps/sha256/sha256.c deps/sha256/sha256.h)
+target_link_libraries(${TARGET} PRIVATE sha256)
+
+target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/gguf-hash/README.md b/llama.cpp/examples/gguf-hash/README.md
new file mode 100644
index 0000000..9871651
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/README.md
@@ -0,0 +1,206 @@
+
+# llama-gguf-hash
+
+CLI to hash GGUF files to detect differences at a per-model and per-tensor level.
+
+**Command line options:**
+
+- `--help`: display help message
+- `--xxh64`: use xxhash 64-bit hash mode (default)
+- `--sha1`: use sha1
+- `--uuid`: generate a UUIDv5 ID
+- `--sha256`: use sha256
+- `--all`: use all hash types
+- `--no-layer`: exclude per-layer hashes
+- `-c`, `--check <manifest>`: verify against a manifest
+
+## About
+
+While most POSIX systems already have hash checking programs like sha256sum, these
+are designed to check entire files. This is not ideal for our purpose if we want
+to check the consistency of the tensor data even after the metadata content of the
+gguf KV store has been updated.
+
+This program is designed to hash a gguf tensor payload on a 'per tensor layer'
+basis in addition to an 'entire tensor model' hash. The intent is that the
+entire-model hash can be checked first; if any inconsistency is detected, the
+per-tensor hashes can then be used to narrow down the specific tensor layer
+that is inconsistent.
+
+For Maintainers:
+- Detection of tensor inconsistency during development and automated tests
+    - This is served by xxh64, which is fast
+    - This is also served by having per-tensor-layer hashes to assist in narrowing down
+      the location of the faulty tensor layer
+    - This is also served by sha1, which is much slower but more widely supported
+
+For Model Creators:
+- Optional consistent UUID generation based on model tensor content
+    - This is served by UUIDv5, which is useful for database keys
+    - llama.cpp UUIDv5 Namespace: `ef001206-dadc-5f6d-a15f-3359e577d4e5`
+        - Made via UUIDv5 URL namespace of `en.wikipedia.org/wiki/Llama.cpp`
+
+For Model Users:
+- Assurance of tensor layer integrity even if metadata was updated
+    - This is served by sha256, which is still considered very secure as of 2024
+
+### Design Note
+
+- The default behavior of this program if no arguments are provided is to hash
+  using xxhash's xxh64 mode, because it is very fast and is primarily targeted
+  towards maintainers who may want to use this in automated tests.
+- xxhash supports xxh32 and xxh128 for 32-bit and 128-bit hashes respectively;
+  however, we picked 64-bit xxhash because most computers are 64-bit as of 2024
+  and thus compute a 64-bit hash most naturally.
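+
+The layered scheme above can be sketched with xxHash's one-shot and streaming
+APIs. This is only an illustration, not the tool's actual code: the tensor
+buffers below are made-up stand-ins for payloads that would really be read
+out of a gguf file.
+
+```c
+#include <stdio.h>
+#include "xxhash.h"
+
+int main(void) {
+    // stand-in tensor payloads; the real tool reads these from the gguf file
+    const unsigned char tensor_0[] = "dummy payload of tensor_0";
+    const unsigned char tensor_1[] = "dummy payload of tensor_1";
+    const unsigned char *tensors[] = { tensor_0, tensor_1 };
+    const size_t        sizes[]   = { sizeof(tensor_0), sizeof(tensor_1) };
+
+    XXH64_state_t * whole = XXH64_createState(); // whole-model hash state
+    if (whole == NULL) return 1;
+    XXH64_reset(whole, 0);
+
+    for (int i = 0; i < 2; i++) {
+        // per-tensor-layer hash: a mismatch can be pinned to a single tensor
+        XXH64_hash_t h = XXH64(tensors[i], sizes[i], 0);
+        printf("xxh64 %016llx test.gguf:tensor_%d\n", (unsigned long long) h, i);
+
+        // the same bytes also contribute to the whole-model hash
+        XXH64_update(whole, tensors[i], sizes[i]);
+    }
+
+    // entire-model hash: the cheap first-pass consistency check
+    printf("xxh64 %016llx test.gguf\n", (unsigned long long) XXH64_digest(whole));
+
+    XXH64_freeState(whole);
+    return 0;
+}
+```
+
+Checking the cheap whole-model digest first, and only walking the per-tensor
+digests on a mismatch, is the narrowing-down workflow described above.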
+
+## Compile Example
+
+```bash
+cmake -B build -DCMAKE_BUILD_TYPE=Debug -DLLAMA_FATAL_WARNINGS=ON
+make -C build clean
+make -C build llama-gguf-hash VERBOSE=1
+./build/bin/llama-gguf-hash test.gguf
+./build/bin/llama-gguf-hash --xxh64 test.gguf
+./build/bin/llama-gguf-hash --sha1 test.gguf
+./build/bin/llama-gguf-hash --uuid test.gguf
+./build/bin/llama-gguf-hash --sha256 test.gguf
+```
+
+## Generation and Verification Example
+
+To generate a manifest we may use this command:
+
+```bash
+./llama-gguf-hash --all test.gguf > test.gguf.manifest
+```
+
+This generates a manifest like the one below, containing multiple hash types as
+well as per-tensor-layer hashes (UUID is excluded, as it is an ID rather than a hash):
+
+```bash
+xxh64 f66e9cd66a4396a0 test.gguf:tensor_0
+sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0
+sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0
+xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1
+sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1
+sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1
+xxh64 a0af5d700049693b test.gguf:tensor_2
+sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2
+sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2
+xxh64 e83fddf559d7b6a6 test.gguf:tensor_3
+sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3
+sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3
+xxh64 1257733306b7992d test.gguf:tensor_4
+sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4
+sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4
+xxh64 d238d16ba4711e58 test.gguf:tensor_5
+sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5
+sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5
+xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6
+sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6
+sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6
+xxh64 c22021c29854f093 test.gguf:tensor_7
+sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7
+sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7
+xxh64 936df61f5d64261f test.gguf:tensor_8
+sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8
+sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8
+xxh64 93fd20c64421c081 test.gguf:tensor_9
+sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9
+sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9
+xxh64 5a54d3aad816f302 test.gguf
+sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf
+sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf
+```
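+
+Each manifest line is plain whitespace-separated text of the form `<hash-type>
+<hex-digest> <file>[:tensor_name]`. As a rough illustration (a hypothetical
+stand-alone snippet, not code from llama-gguf-hash itself), such a line can be
+split like so:
+
+```c
+#include <stdio.h>
+
+int main(void) {
+    const char * line = "xxh64 f66e9cd66a4396a0 test.gguf:tensor_0";
+    char type[16], digest[129], target[256];
+
+    // three whitespace-separated fields: hash type, hex digest, hashed target
+    if (sscanf(line, "%15s %128s %255s", type, digest, target) == 3) {
+        printf("type=%s digest=%s target=%s\n", type, digest, target);
+    }
+    return 0;
+}
+```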
+
+We can then use the normal check command, which will by default verify against
+the highest-security-strength hash in the manifest:
+
+```bash
+$ ./llama-gguf-hash --check test.gguf.manifest test.gguf
+manifest test.gguf.manifest sha256 sha1 xxh64
+sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
+sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
+sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
+sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
+sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
+sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
+sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
+sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
+sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
+sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
+sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
+
+Verification results for test.gguf.manifest - Success
+```
+
+Or we may explicitly ask for a faster hash like:
+
+```bash
+$ ./llama-gguf-hash --check test.gguf.manifest --xxh64 test.gguf
+manifest test.gguf.manifest sha256 sha1 xxh64
+xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
+xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
+xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
+xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
+xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
+xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
+xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
+xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
+xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
+xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
+xxh64 5a54d3aad816f302 test.gguf - Ok
+
+Verification results for test.gguf.manifest - Success
+```
+
+Or maybe we want to just check that all the hashes are valid:
+
+```bash
+$ ./llama-gguf-hash --check test.gguf.manifest --all test.gguf
+manifest test.gguf.manifest sha256 sha1 xxh64
+xxh64 f66e9cd66a4396a0 test.gguf:tensor_0 - Ok
+sha1 59f79ecefd8125a996fdf419239051a7e99e5f20 test.gguf:tensor_0 - Ok
+sha256 c0510d38fa060c46265e0160a85c7243096b01dd31c2f355bdbb5516b20de1bd test.gguf:tensor_0 - Ok
+xxh64 7d3a1f9ac04d0537 test.gguf:tensor_1 - Ok
+sha1 4765f592eacf096df4628ba59476af94d767080a test.gguf:tensor_1 - Ok
+sha256 8514cbcc73692a2c56bd7a33a022edd5ff819614bd23b19915d7224387f397a7 test.gguf:tensor_1 - Ok
+xxh64 a0af5d700049693b test.gguf:tensor_2 - Ok
+sha1 25cbfbad4513cc348e2c95ebdee69d6ff2fd8753 test.gguf:tensor_2 - Ok
+sha256 947e6b36e20f2cc95e1d2ce1c1669d813d574657ac6b5ac5196158d454d35180 test.gguf:tensor_2 - Ok
+xxh64 e83fddf559d7b6a6 test.gguf:tensor_3 - Ok
+sha1 a9cba73e2d90f2ee3dae2548caa42bef3fe6a96c test.gguf:tensor_3 - Ok
+sha256 423b044e016d8ac73c39f23f60bf01bedef5ecb03c0230accd824c91fe86f1a1 test.gguf:tensor_3 - Ok
+xxh64 1257733306b7992d test.gguf:tensor_4 - Ok
+sha1 d7bc61db93bb685ce9d598da89717c66729b7543 test.gguf:tensor_4 - Ok
+sha256 79737cb3912d4201384cf7f16a1a37ff7823f23ea796cb205b6ca361ab9e3ebf test.gguf:tensor_4 - Ok
+xxh64 d238d16ba4711e58 test.gguf:tensor_5 - Ok
+sha1 0706566c198fe1072f37e0a5135b4b5f23654c52 test.gguf:tensor_5 - Ok
+sha256 60949be8298eced0ecdde64487643d018407bd261691e061d9e9c3dbc9fd358b test.gguf:tensor_5 - Ok
+xxh64 3fbc3b65ab8c7f39 test.gguf:tensor_6 - Ok
+sha1 73922a0727226a409049f6fc3172a52219ca6f00 test.gguf:tensor_6 - Ok
+sha256 574f4c46ff384a3b9a225eb955d2a871847a2e8b3fa59387a8252832e92ef7b0 test.gguf:tensor_6 - Ok
+xxh64 c22021c29854f093 test.gguf:tensor_7 - Ok
+sha1 efc39cece6a951188fc41e354c73bbfe6813d447 test.gguf:tensor_7 - Ok
+sha256 4c0410cd3c500f078ae5b21e8dc9eb79e29112713b2ab58a882f82a3868d4d75 test.gguf:tensor_7 - Ok
+xxh64 936df61f5d64261f test.gguf:tensor_8 - Ok
+sha1 c2490296d789a4f34398a337fed8377d943d9f06 test.gguf:tensor_8 - Ok
+sha256 c4401313feeba0261275c3b25bd2d8fe40ce04e0f440c2980ed0e9674c30ff01 test.gguf:tensor_8 - Ok
+xxh64 93fd20c64421c081 test.gguf:tensor_9 - Ok
+sha1 7047ce1e78437a6884337a3751c7ee0421918a65 test.gguf:tensor_9 - Ok
+sha256 23d57cf0d7a6e90b0b3616b41300e0cd354781e812add854a5f95aa55f2bc514 test.gguf:tensor_9 - Ok
+xxh64 5a54d3aad816f302 test.gguf - Ok
+sha1 d15be52c4ff213e823cb6dd13af7ee2f978e7042 test.gguf - Ok
+sha256 7dd641b32f59b60dbd4b5420c4b0f6321ccf48f58f6ae201a3dbc4a58a27c6e4 test.gguf - Ok
+
+Verification results for test.gguf.manifest - Success
+```
+
+
+## Crypto/Hash Libraries Used
+
+These micro C library dependencies were installed via the [clib C package manager](https://github.com/clibs):
+
+- https://github.com/Cyan4973/xxHash
+- https://github.com/clibs/sha1/
+- https://github.com/jb55/sha256.c
diff --git a/llama.cpp/examples/gguf-hash/deps/rotate-bits/package.json b/llama.cpp/examples/gguf-hash/deps/rotate-bits/package.json
new file mode 100644
index 0000000..74c0bef
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/deps/rotate-bits/package.json
@@ -0,0 +1,13 @@
+{
+  "name": "rotate-bits",
+  "version": "0.1.1",
+  "repo": "jb55/rotate-bits.h",
+  "description": "rotate bits",
+  "keywords": ["rotl", "rotr"],
+  "src": ["rotate-bits.h"],
+  "license": "Public Domain",
+  "development": {
+    "thlorenz/tap.c": "*"
+  }
+}
+
diff --git a/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h b/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h
new file mode 100644
index 0000000..75c4881
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/deps/rotate-bits/rotate-bits.h
@@ -0,0 +1,46 @@
+
+
+#ifndef __ROTATE_DEFS_H
+#define __ROTATE_DEFS_H
+
+#ifdef _MSC_VER
+
+#include <stdlib.h>
+
+#define ROTL32(v, n) _rotl((v), (n))
+#define ROTL64(v, n) _rotl64((v), (n))
+
+#define ROTR32(v, n) _rotr((v), (n))
+#define ROTR64(v, n) _rotr64((v), (n))
+
+#else
+
+#include <stdint.h>
+
+#define U8V(v)  ((uint8_t)(v)  & 0xFFU)
+#define U16V(v) ((uint16_t)(v) & 0xFFFFU)
+#define U32V(v) ((uint32_t)(v) & 0xFFFFFFFFU)
+#define U64V(v) ((uint64_t)(v) & 0xFFFFFFFFFFFFFFFFU)
+
+#define ROTL32(v, n) \
+  (U32V((uint32_t)(v) << (n)) | ((uint32_t)(v) >> (32 - (n))))
+
+// tests fail if we don't have this cast...
+#define ROTL64(v, n) \ + (U64V((uint64_t)(v) << (n)) | ((uint64_t)(v) >> (64 - (n)))) + +#define ROTR32(v, n) ROTL32(v, 32 - (n)) +#define ROTR64(v, n) ROTL64(v, 64 - (n)) + +#endif + +#define ROTL8(v, n) \ + (U8V((uint8_t)(v) << (n)) | ((uint8_t)(v) >> (8 - (n)))) + +#define ROTL16(v, n) \ + (U16V((uint16_t)(v) << (n)) | ((uint16_t)(v) >> (16 - (n)))) + +#define ROTR8(v, n) ROTL8(v, 8 - (n)) +#define ROTR16(v, n) ROTL16(v, 16 - (n)) + +#endif diff --git a/llama.cpp/examples/gguf-hash/deps/sha1/package.json b/llama.cpp/examples/gguf-hash/deps/sha1/package.json new file mode 100644 index 0000000..6a5843d --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/sha1/package.json @@ -0,0 +1,9 @@ +{ + "name": "sha1", + "version": "0.0.1", + "repo": "clibs/sha1", + "description": "sha1 hash algorithm", + "keywords": ["sha1", "hash"], + "license": "public domain", + "src": ["sha1.c", "sha1.h"] +} diff --git a/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c b/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c new file mode 100644 index 0000000..76cd6ca --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/sha1/sha1.c @@ -0,0 +1,295 @@ +/* +SHA-1 in C +By Steve Reid <steve@edmweb.com> +100% Public Domain + +Test Vectors (from FIPS PUB 180-1) +"abc" + A9993E36 4706816A BA3E2571 7850C26C 9CD0D89D +"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq" + 84983E44 1C3BD26E BAAE4AA1 F95129E5 E54670F1 +A million repetitions of "a" + 34AA973C D4C4DAA4 F61EEB2B DBAD2731 6534016F +*/ + +/* #define LITTLE_ENDIAN * This should be #define'd already, if true. */ +/* #define SHA1HANDSOFF * Copies data before messing with it. */ + +#define SHA1HANDSOFF + +#include <stdio.h> +#include <string.h> + +/* for uint32_t */ +#include <stdint.h> + +#include "sha1.h" + + +#define rol(value, bits) (((value) << (bits)) | ((value) >> (32 - (bits)))) + +/* blk0() and blk() perform the initial expand. */ +/* I got the idea of expanding during the round function from SSLeay */ +#if BYTE_ORDER == LITTLE_ENDIAN +#define blk0(i) (block->l[i] = (rol(block->l[i],24)&0xFF00FF00) \ + |(rol(block->l[i],8)&0x00FF00FF)) +#elif BYTE_ORDER == BIG_ENDIAN +#define blk0(i) block->l[i] +#else +#error "Endianness not defined!" +#endif +#define blk(i) (block->l[i&15] = rol(block->l[(i+13)&15]^block->l[(i+8)&15] \ + ^block->l[(i+2)&15]^block->l[i&15],1)) + +/* (R0+R1), R2, R3, R4 are the different operations used in SHA1 */ +#define R0(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk0(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R1(v,w,x,y,z,i) z+=((w&(x^y))^y)+blk(i)+0x5A827999+rol(v,5);w=rol(w,30); +#define R2(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0x6ED9EBA1+rol(v,5);w=rol(w,30); +#define R3(v,w,x,y,z,i) z+=(((w|x)&y)|(w&x))+blk(i)+0x8F1BBCDC+rol(v,5);w=rol(w,30); +#define R4(v,w,x,y,z,i) z+=(w^x^y)+blk(i)+0xCA62C1D6+rol(v,5);w=rol(w,30); + + +/* Hash a single 512-bit block. This is the core of the algorithm. */ + +void SHA1Transform( + uint32_t state[5], + const unsigned char buffer[64] +) +{ + uint32_t a, b, c, d, e; + + typedef union + { + unsigned char c[64]; + uint32_t l[16]; + } CHAR64LONG16; + +#ifdef SHA1HANDSOFF + CHAR64LONG16 block[1]; /* use array to appear as a pointer */ + + memcpy(block, buffer, 64); +#else + /* The following had better never be used because it causes the + * pointer-to-const buffer to be cast into a pointer to non-const. + * And the result is written through. I threw a "const" in, hoping + * this will cause a diagnostic. 
+ */ + CHAR64LONG16 *block = (const CHAR64LONG16 *) buffer; +#endif + /* Copy context->state[] to working vars */ + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + /* 4 rounds of 20 operations each. Loop unrolled. */ + R0(a, b, c, d, e, 0); + R0(e, a, b, c, d, 1); + R0(d, e, a, b, c, 2); + R0(c, d, e, a, b, 3); + R0(b, c, d, e, a, 4); + R0(a, b, c, d, e, 5); + R0(e, a, b, c, d, 6); + R0(d, e, a, b, c, 7); + R0(c, d, e, a, b, 8); + R0(b, c, d, e, a, 9); + R0(a, b, c, d, e, 10); + R0(e, a, b, c, d, 11); + R0(d, e, a, b, c, 12); + R0(c, d, e, a, b, 13); + R0(b, c, d, e, a, 14); + R0(a, b, c, d, e, 15); + R1(e, a, b, c, d, 16); + R1(d, e, a, b, c, 17); + R1(c, d, e, a, b, 18); + R1(b, c, d, e, a, 19); + R2(a, b, c, d, e, 20); + R2(e, a, b, c, d, 21); + R2(d, e, a, b, c, 22); + R2(c, d, e, a, b, 23); + R2(b, c, d, e, a, 24); + R2(a, b, c, d, e, 25); + R2(e, a, b, c, d, 26); + R2(d, e, a, b, c, 27); + R2(c, d, e, a, b, 28); + R2(b, c, d, e, a, 29); + R2(a, b, c, d, e, 30); + R2(e, a, b, c, d, 31); + R2(d, e, a, b, c, 32); + R2(c, d, e, a, b, 33); + R2(b, c, d, e, a, 34); + R2(a, b, c, d, e, 35); + R2(e, a, b, c, d, 36); + R2(d, e, a, b, c, 37); + R2(c, d, e, a, b, 38); + R2(b, c, d, e, a, 39); + R3(a, b, c, d, e, 40); + R3(e, a, b, c, d, 41); + R3(d, e, a, b, c, 42); + R3(c, d, e, a, b, 43); + R3(b, c, d, e, a, 44); + R3(a, b, c, d, e, 45); + R3(e, a, b, c, d, 46); + R3(d, e, a, b, c, 47); + R3(c, d, e, a, b, 48); + R3(b, c, d, e, a, 49); + R3(a, b, c, d, e, 50); + R3(e, a, b, c, d, 51); + R3(d, e, a, b, c, 52); + R3(c, d, e, a, b, 53); + R3(b, c, d, e, a, 54); + R3(a, b, c, d, e, 55); + R3(e, a, b, c, d, 56); + R3(d, e, a, b, c, 57); + R3(c, d, e, a, b, 58); + R3(b, c, d, e, a, 59); + R4(a, b, c, d, e, 60); + R4(e, a, b, c, d, 61); + R4(d, e, a, b, c, 62); + R4(c, d, e, a, b, 63); + R4(b, c, d, e, a, 64); + R4(a, b, c, d, e, 65); + R4(e, a, b, c, d, 66); + R4(d, e, a, b, c, 67); + R4(c, d, e, a, b, 68); + R4(b, c, d, e, a, 69); + R4(a, b, c, d, e, 70); + R4(e, a, b, c, d, 71); + R4(d, e, a, b, c, 72); + R4(c, d, e, a, b, 73); + R4(b, c, d, e, a, 74); + R4(a, b, c, d, e, 75); + R4(e, a, b, c, d, 76); + R4(d, e, a, b, c, 77); + R4(c, d, e, a, b, 78); + R4(b, c, d, e, a, 79); + /* Add the working vars back into context.state[] */ + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + /* Wipe variables */ + a = b = c = d = e = 0; +#ifdef SHA1HANDSOFF + memset(block, '\0', sizeof(block)); +#endif +} + + +/* SHA1Init - Initialize new context */ + +void SHA1Init( + SHA1_CTX * context +) +{ + /* SHA1 initialization constants */ + context->state[0] = 0x67452301; + context->state[1] = 0xEFCDAB89; + context->state[2] = 0x98BADCFE; + context->state[3] = 0x10325476; + context->state[4] = 0xC3D2E1F0; + context->count[0] = context->count[1] = 0; +} + + +/* Run your data through this. */ + +void SHA1Update( + SHA1_CTX * context, + const unsigned char *data, + uint32_t len +) +{ + uint32_t i; + + uint32_t j; + + j = context->count[0]; + if ((context->count[0] += len << 3) < j) + context->count[1]++; + context->count[1] += (len >> 29); + j = (j >> 3) & 63; + if ((j + len) > 63) + { + memcpy(&context->buffer[j], data, (i = 64 - j)); + SHA1Transform(context->state, context->buffer); + for (; i + 63 < len; i += 64) + { + SHA1Transform(context->state, &data[i]); + } + j = 0; + } + else + i = 0; + memcpy(&context->buffer[j], &data[i], len - i); +} + + +/* Add padding and return the message digest. 
 */
+
+void SHA1Final(
+    unsigned char digest[20],
+    SHA1_CTX * context
+)
+{
+    unsigned i;
+
+    unsigned char finalcount[8];
+
+    unsigned char c;
+
+#if 0    /* untested "improvement" by DHR */
+    /* Convert context->count to a sequence of bytes
+     * in finalcount. Second element first, but
+     * big-endian order within element.
+     * But we do it all backwards.
+     */
+    unsigned char *fcp = &finalcount[8];
+
+    for (i = 0; i < 2; i++)
+    {
+        uint32_t t = context->count[i];
+
+        int j;
+
+        for (j = 0; j < 4; t >>= 8, j++)
+            *--fcp = (unsigned char) t;
+    }
+#else
+    for (i = 0; i < 8; i++)
+    {
+        finalcount[i] = (unsigned char) ((context->count[(i >= 4 ? 0 : 1)] >> ((3 - (i & 3)) * 8)) & 255);      /* Endian independent */
+    }
+#endif
+    c = 0200;
+    SHA1Update(context, &c, 1);
+    while ((context->count[0] & 504) != 448)
+    {
+        c = 0000;
+        SHA1Update(context, &c, 1);
+    }
+    SHA1Update(context, finalcount, 8); /* Should cause a SHA1Transform() */
+    for (i = 0; i < 20; i++)
+    {
+        digest[i] = (unsigned char)
+            ((context->state[i >> 2] >> ((3 - (i & 3)) * 8)) & 255);
+    }
+    /* Wipe variables */
+    memset(context, '\0', sizeof(*context));
+    memset(&finalcount, '\0', sizeof(finalcount));
+}
+
+void SHA1(
+    char *hash_out,
+    const char *str,
+    uint32_t len)
+{
+    SHA1_CTX ctx;
+    unsigned int ii;
+
+    SHA1Init(&ctx);
+    for (ii=0; ii<len; ii+=1)
+        SHA1Update(&ctx, (const unsigned char*)str + ii, 1);
+    SHA1Final((unsigned char *)hash_out, &ctx);
+}
+
diff --git a/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h b/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h
new file mode 100644
index 0000000..f492009
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/deps/sha1/sha1.h
@@ -0,0 +1,52 @@
+#ifndef SHA1_H
+#define SHA1_H
+
+/*
+   SHA-1 in C
+   By Steve Reid <steve@edmweb.com>
+   100% Public Domain
+ */
+
+#include "stdint.h"
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+typedef struct
+{
+    uint32_t state[5];
+    uint32_t count[2];
+    unsigned char buffer[64];
+} SHA1_CTX;
+
+void SHA1Transform(
+    uint32_t state[5],
+    const unsigned char buffer[64]
+    );
+
+void SHA1Init(
+    SHA1_CTX * context
+    );
+
+void SHA1Update(
+    SHA1_CTX * context,
+    const unsigned char *data,
+    uint32_t len
+    );
+
+void SHA1Final(
+    unsigned char digest[20],
+    SHA1_CTX * context
+    );
+
+void SHA1(
+    char *hash_out,
+    const char *str,
+    uint32_t len);
+
+#if defined(__cplusplus)
+}
+#endif
+
+#endif /* SHA1_H */
diff --git a/llama.cpp/examples/gguf-hash/deps/sha256/package.json b/llama.cpp/examples/gguf-hash/deps/sha256/package.json
new file mode 100644
index 0000000..b92a041
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/deps/sha256/package.json
@@ -0,0 +1,15 @@
+{
+  "name": "sha256",
+  "version": "0.0.2",
+  "repo": "jb55/sha256.c",
+  "description": "sha256 in c",
+  "keywords": ["sha256", "sha2"],
+  "src": ["sha256.c", "sha256.h"],
+  "dependencies": {
+    "jb55/rotate-bits.h": "0.1.1"
+  },
+  "development": {
+    "thlorenz/tap.c": "*"
+  }
+}
+
diff --git a/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c b/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c
new file mode 100644
index 0000000..a7a87ae
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/deps/sha256/sha256.c
@@ -0,0 +1,221 @@
+/* Crypto/Sha256.c -- SHA-256 Hash
+2010-06-11 : Igor Pavlov : Public domain
+This code is based on public domain code from Wei Dai's Crypto++ library.
*/ + +#include "rotate-bits/rotate-bits.h" +#include "sha256.h" + +/* define it for speed optimization */ +#define _SHA256_UNROLL +#define _SHA256_UNROLL2 + +void +sha256_init(sha256_t *p) +{ + p->state[0] = 0x6a09e667; + p->state[1] = 0xbb67ae85; + p->state[2] = 0x3c6ef372; + p->state[3] = 0xa54ff53a; + p->state[4] = 0x510e527f; + p->state[5] = 0x9b05688c; + p->state[6] = 0x1f83d9ab; + p->state[7] = 0x5be0cd19; + p->count = 0; +} + +#define S0(x) (ROTR32(x, 2) ^ ROTR32(x,13) ^ ROTR32(x, 22)) +#define S1(x) (ROTR32(x, 6) ^ ROTR32(x,11) ^ ROTR32(x, 25)) +#define s0(x) (ROTR32(x, 7) ^ ROTR32(x,18) ^ (x >> 3)) +#define s1(x) (ROTR32(x,17) ^ ROTR32(x,19) ^ (x >> 10)) + +#define blk0(i) (W[i] = data[i]) +#define blk2(i) (W[i&15] += s1(W[(i-2)&15]) + W[(i-7)&15] + s0(W[(i-15)&15])) + +#define Ch(x,y,z) (z^(x&(y^z))) +#define Maj(x,y,z) ((x&y)|(z&(x|y))) + +#define a(i) T[(0-(i))&7] +#define b(i) T[(1-(i))&7] +#define c(i) T[(2-(i))&7] +#define d(i) T[(3-(i))&7] +#define e(i) T[(4-(i))&7] +#define f(i) T[(5-(i))&7] +#define g(i) T[(6-(i))&7] +#define h(i) T[(7-(i))&7] + + +#ifdef _SHA256_UNROLL2 + +#define R(a,b,c,d,e,f,g,h, i) h += S1(e) + Ch(e,f,g) + K[i+j] + (j?blk2(i):blk0(i));\ + d += h; h += S0(a) + Maj(a, b, c) + +#define RX_8(i) \ + R(a,b,c,d,e,f,g,h, i); \ + R(h,a,b,c,d,e,f,g, (i+1)); \ + R(g,h,a,b,c,d,e,f, (i+2)); \ + R(f,g,h,a,b,c,d,e, (i+3)); \ + R(e,f,g,h,a,b,c,d, (i+4)); \ + R(d,e,f,g,h,a,b,c, (i+5)); \ + R(c,d,e,f,g,h,a,b, (i+6)); \ + R(b,c,d,e,f,g,h,a, (i+7)) + +#else + +#define R(i) h(i) += S1(e(i)) + Ch(e(i),f(i),g(i)) + K[i+j] + (j?blk2(i):blk0(i));\ + d(i) += h(i); h(i) += S0(a(i)) + Maj(a(i), b(i), c(i)) + +#ifdef _SHA256_UNROLL + +#define RX_8(i) R(i+0); R(i+1); R(i+2); R(i+3); R(i+4); R(i+5); R(i+6); R(i+7); + +#endif + +#endif + +static const uint32_t K[64] = { + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, + 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, + 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, + 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, + 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, + 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, + 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, + 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, + 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2 +}; + +static void +sha256_transform(uint32_t *state, const uint32_t *data) +{ + uint32_t W[16] = {0}; + unsigned j; + #ifdef _SHA256_UNROLL2 + uint32_t a,b,c,d,e,f,g,h; + a = state[0]; + b = state[1]; + c = state[2]; + d = state[3]; + e = state[4]; + f = state[5]; + g = state[6]; + h = state[7]; + #else + uint32_t T[8]; + for (j = 0; j < 8; j++) + T[j] = state[j]; + #endif + + for (j = 0; j < 64; j += 16) + { + #if defined(_SHA256_UNROLL) || defined(_SHA256_UNROLL2) + RX_8(0); RX_8(8); + #else + unsigned i; + for (i = 0; i < 16; i++) { R(i); } + #endif + } + + #ifdef _SHA256_UNROLL2 + state[0] += a; + state[1] += b; + state[2] += c; + state[3] += d; + state[4] += e; + state[5] += f; + state[6] += g; + state[7] += h; + #else + for (j = 0; j < 8; j++) + state[j] += T[j]; + #endif + + /* Wipe variables */ + /* memset(W, 0, sizeof(W)); */ + /* memset(T, 0, sizeof(T)); */ +} + +#undef S0 +#undef S1 +#undef s0 +#undef s1 + +static void 
+sha256_write_byte_block(sha256_t *p) +{ + uint32_t data32[16]; + unsigned i; + for (i = 0; i < 16; i++) + data32[i] = + ((uint32_t)(p->buffer[i * 4 ]) << 24) + + ((uint32_t)(p->buffer[i * 4 + 1]) << 16) + + ((uint32_t)(p->buffer[i * 4 + 2]) << 8) + + ((uint32_t)(p->buffer[i * 4 + 3])); + sha256_transform(p->state, data32); +} + + +void +sha256_hash(unsigned char *buf, const unsigned char *data, size_t size) +{ + sha256_t hash; + sha256_init(&hash); + sha256_update(&hash, data, size); + sha256_final(&hash, buf); +} + + +void +sha256_update(sha256_t *p, const unsigned char *data, size_t size) +{ + uint32_t curBufferPos = (uint32_t)p->count & 0x3F; + while (size > 0) + { + p->buffer[curBufferPos++] = *data++; + p->count++; + size--; + if (curBufferPos == 64) + { + curBufferPos = 0; + sha256_write_byte_block(p); + } + } +} + + +void +sha256_final(sha256_t *p, unsigned char *digest) +{ + uint64_t lenInBits = (p->count << 3); + uint32_t curBufferPos = (uint32_t)p->count & 0x3F; + unsigned i; + p->buffer[curBufferPos++] = 0x80; + while (curBufferPos != (64 - 8)) + { + curBufferPos &= 0x3F; + if (curBufferPos == 0) + sha256_write_byte_block(p); + p->buffer[curBufferPos++] = 0; + } + for (i = 0; i < 8; i++) + { + p->buffer[curBufferPos++] = (unsigned char)(lenInBits >> 56); + lenInBits <<= 8; + } + sha256_write_byte_block(p); + + for (i = 0; i < 8; i++) + { + *digest++ = (unsigned char)(p->state[i] >> 24); + *digest++ = (unsigned char)(p->state[i] >> 16); + *digest++ = (unsigned char)(p->state[i] >> 8); + *digest++ = (unsigned char)(p->state[i]); + } + sha256_init(p); +} diff --git a/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h b/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h new file mode 100644 index 0000000..21657e6 --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/sha256/sha256.h @@ -0,0 +1,24 @@ +/* Sha256.h -- SHA-256 Hash +2010-06-11 : Igor Pavlov : Public domain */ + +#ifndef __CRYPTO_SHA256_H +#define __CRYPTO_SHA256_H + +#include <stdlib.h> +#include <stdint.h> + +#define SHA256_DIGEST_SIZE 32 + +typedef struct sha256_t +{ + uint32_t state[8]; + uint64_t count; + unsigned char buffer[64]; +} sha256_t; + +void sha256_init(sha256_t *p); +void sha256_update(sha256_t *p, const unsigned char *data, size_t size); +void sha256_final(sha256_t *p, unsigned char *digest); +void sha256_hash(unsigned char *buf, const unsigned char *data, size_t size); + +#endif diff --git a/llama.cpp/examples/gguf-hash/deps/xxhash/clib.json b/llama.cpp/examples/gguf-hash/deps/xxhash/clib.json new file mode 100644 index 0000000..242343c --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/xxhash/clib.json @@ -0,0 +1,12 @@ +{ + "name": "xxhash", + "version": "0.8.2", + "repo": "Cyan4973/xxhash", + "description": "Extremely fast non-cryptographic hash algorithm", + "keywords": ["xxhash", "hashing"], + "license": "BSD-2-Clause", + "src": [ + "xxhash.c", + "xxhash.h" + ] +} diff --git a/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c b/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c new file mode 100644 index 0000000..e60cc37 --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.c @@ -0,0 +1,42 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Copyright (C) 2012-2023 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above 
copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/* + * xxhash.c instantiates functions defined in xxhash.h + */ + +#define XXH_STATIC_LINKING_ONLY /* access advanced declarations */ +#define XXH_IMPLEMENTATION /* access definitions */ + +#include "xxhash.h" diff --git a/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h b/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h new file mode 100644 index 0000000..c0fafe2 --- /dev/null +++ b/llama.cpp/examples/gguf-hash/deps/xxhash/xxhash.h @@ -0,0 +1,7093 @@ +/* + * xxHash - Extremely Fast Hash algorithm + * Header File + * Copyright (C) 2012-2023 Yann Collet + * + * BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * You can contact the author at: + * - xxHash homepage: https://www.xxhash.com + * - xxHash source repository: https://github.com/Cyan4973/xxHash + */ + +/*! + * @mainpage xxHash + * + * xxHash is an extremely fast non-cryptographic hash algorithm, working at RAM speed + * limits. 
+ * + * It is proposed in four flavors, in three families: + * 1. @ref XXH32_family + * - Classic 32-bit hash function. Simple, compact, and runs on almost all + * 32-bit and 64-bit systems. + * 2. @ref XXH64_family + * - Classic 64-bit adaptation of XXH32. Just as simple, and runs well on most + * 64-bit systems (but _not_ 32-bit systems). + * 3. @ref XXH3_family + * - Modern 64-bit and 128-bit hash function family which features improved + * strength and performance across the board, especially on smaller data. + * It benefits greatly from SIMD and 64-bit without requiring it. + * + * Benchmarks + * --- + * The reference system uses an Intel i7-9700K CPU, and runs Ubuntu x64 20.04. + * The open source benchmark program is compiled with clang v10.0 using -O3 flag. + * + * | Hash Name | ISA ext | Width | Large Data Speed | Small Data Velocity | + * | -------------------- | ------- | ----: | ---------------: | ------------------: | + * | XXH3_64bits() | @b AVX2 | 64 | 59.4 GB/s | 133.1 | + * | MeowHash | AES-NI | 128 | 58.2 GB/s | 52.5 | + * | XXH3_128bits() | @b AVX2 | 128 | 57.9 GB/s | 118.1 | + * | CLHash | PCLMUL | 64 | 37.1 GB/s | 58.1 | + * | XXH3_64bits() | @b SSE2 | 64 | 31.5 GB/s | 133.1 | + * | XXH3_128bits() | @b SSE2 | 128 | 29.6 GB/s | 118.1 | + * | RAM sequential read | | N/A | 28.0 GB/s | N/A | + * | ahash | AES-NI | 64 | 22.5 GB/s | 107.2 | + * | City64 | | 64 | 22.0 GB/s | 76.6 | + * | T1ha2 | | 64 | 22.0 GB/s | 99.0 | + * | City128 | | 128 | 21.7 GB/s | 57.7 | + * | FarmHash | AES-NI | 64 | 21.3 GB/s | 71.9 | + * | XXH64() | | 64 | 19.4 GB/s | 71.0 | + * | SpookyHash | | 64 | 19.3 GB/s | 53.2 | + * | Mum | | 64 | 18.0 GB/s | 67.0 | + * | CRC32C | SSE4.2 | 32 | 13.0 GB/s | 57.9 | + * | XXH32() | | 32 | 9.7 GB/s | 71.9 | + * | City32 | | 32 | 9.1 GB/s | 66.0 | + * | Blake3* | @b AVX2 | 256 | 4.4 GB/s | 8.1 | + * | Murmur3 | | 32 | 3.9 GB/s | 56.1 | + * | SipHash* | | 64 | 3.0 GB/s | 43.2 | + * | Blake3* | @b SSE2 | 256 | 2.4 GB/s | 8.1 | + * | HighwayHash | | 64 | 1.4 GB/s | 6.0 | + * | FNV64 | | 64 | 1.2 GB/s | 62.7 | + * | Blake2* | | 256 | 1.1 GB/s | 5.1 | + * | SHA1* | | 160 | 0.8 GB/s | 5.6 | + * | MD5* | | 128 | 0.6 GB/s | 7.8 | + * @note + * - Hashes which require a specific ISA extension are noted. SSE2 is also noted, + * even though it is mandatory on x64. + * - Hashes with an asterisk are cryptographic. Note that MD5 is non-cryptographic + * by modern standards. + * - Small data velocity is a rough average of algorithm's efficiency for small + * data. For more accurate information, see the wiki. + * - More benchmarks and strength tests are found on the wiki: + * https://github.com/Cyan4973/xxHash/wiki + * + * Usage + * ------ + * All xxHash variants use a similar API. Changing the algorithm is a trivial + * substitution. + * + * @pre + * For functions which take an input and length parameter, the following + * requirements are assumed: + * - The range from [`input`, `input + length`) is valid, readable memory. + * - The only exception is if the `length` is `0`, `input` may be `NULL`. + * - For C++, the objects must have the *TriviallyCopyable* property, as the + * functions access bytes directly as if it was an array of `unsigned char`. + * + * @anchor single_shot_example + * **Single Shot** + * + * These functions are stateless functions which hash a contiguous block of memory, + * immediately returning the result. They are the easiest and usually the fastest + * option. 
+ *
+ * XXH32(), XXH64(), XXH3_64bits(), XXH3_128bits()
+ *
+ * @code{.c}
+ * #include <string.h>
+ * #include "xxhash.h"
+ *
+ * // Example for a function which hashes a null terminated string with XXH32().
+ * XXH32_hash_t hash_string(const char* string, XXH32_hash_t seed)
+ * {
+ *     // NULL pointers are only valid if the length is zero
+ *     size_t length = (string == NULL) ? 0 : strlen(string);
+ *     return XXH32(string, length, seed);
+ * }
+ * @endcode
+ *
+ *
+ * @anchor streaming_example
+ * **Streaming**
+ *
+ * These groups of functions allow incremental hashing of unknown size, even
+ * more than what would fit in a size_t.
+ *
+ * XXH32_reset(), XXH64_reset(), XXH3_64bits_reset(), XXH3_128bits_reset()
+ *
+ * @code{.c}
+ * #include <stdio.h>
+ * #include <assert.h>
+ * #include "xxhash.h"
+ * // Example for a function which hashes a FILE incrementally with XXH3_64bits().
+ * XXH64_hash_t hashFile(FILE* f)
+ * {
+ *     // Allocate a state struct. Do not just use malloc() or new.
+ *     XXH3_state_t* state = XXH3_createState();
+ *     assert(state != NULL && "Out of memory!");
+ *     // Reset the state to start a new hashing session.
+ *     XXH3_64bits_reset(state);
+ *     char buffer[4096];
+ *     size_t count;
+ *     // Read the file in chunks
+ *     while ((count = fread(buffer, 1, sizeof(buffer), f)) != 0) {
+ *         // Run update() as many times as necessary to process the data
+ *         XXH3_64bits_update(state, buffer, count);
+ *     }
+ *     // Retrieve the finalized hash. This will not change the state.
+ *     XXH64_hash_t result = XXH3_64bits_digest(state);
+ *     // Free the state. Do not use free().
+ *     XXH3_freeState(state);
+ *     return result;
+ * }
+ * @endcode
+ *
+ * Streaming functions generate the xxHash value from an incremental input.
+ * This method is slower than single-call functions, due to state management.
+ * For small inputs, prefer `XXH32()` and `XXH64()`, which are better optimized.
+ *
+ * An XXH state must first be allocated using `XXH*_createState()`.
+ *
+ * Start a new hash by initializing the state with a seed using `XXH*_reset()`.
+ *
+ * Then, feed the hash state by calling `XXH*_update()` as many times as necessary.
+ *
+ * The function returns an error code, with 0 meaning OK, and any other value
+ * meaning there is an error.
+ *
+ * Finally, a hash value can be produced anytime, by using `XXH*_digest()`.
+ * This function returns the nn-bits hash as an int or long long.
+ *
+ * It's still possible to continue inserting input into the hash state after a
+ * digest, and generate new hash values later on by invoking `XXH*_digest()`.
+ *
+ * When done, release the state using `XXH*_freeState()`.
+ *
+ *
+ * @anchor canonical_representation_example
+ * **Canonical Representation**
+ *
+ * The default return values from XXH functions are unsigned 32, 64 and 128 bit
+ * integers.
+ * This is the simplest and fastest format for further post-processing.
+ *
+ * However, this leaves open the question of what is the order on the byte level,
+ * since little and big endian conventions will store the same number differently.
+ *
+ * The canonical representation settles this issue by mandating big-endian
+ * convention, the same convention as human-readable numbers (large digits first).
+ *
+ * When writing hash values to storage, sending them over a network, or printing
+ * them, it's highly recommended to use the canonical representation to ensure
+ * portability across a wider range of systems, present and future.
+ * + * The following functions allow transformation of hash values to and from + * canonical format. + * + * XXH32_canonicalFromHash(), XXH32_hashFromCanonical(), + * XXH64_canonicalFromHash(), XXH64_hashFromCanonical(), + * XXH128_canonicalFromHash(), XXH128_hashFromCanonical(), + * + * @code{.c} + * #include <stdio.h> + * #include "xxhash.h" + * + * // Example for a function which prints XXH32_hash_t in human readable format + * void printXxh32(XXH32_hash_t hash) + * { + * XXH32_canonical_t cano; + * XXH32_canonicalFromHash(&cano, hash); + * size_t i; + * for(i = 0; i < sizeof(cano.digest); ++i) { + * printf("%02x", cano.digest[i]); + * } + * printf("\n"); + * } + * + * // Example for a function which converts XXH32_canonical_t to XXH32_hash_t + * XXH32_hash_t convertCanonicalToXxh32(XXH32_canonical_t cano) + * { + * XXH32_hash_t hash = XXH32_hashFromCanonical(&cano); + * return hash; + * } + * @endcode + * + * + * @file xxhash.h + * xxHash prototypes and implementation + */ + +#if defined (__cplusplus) +extern "C" { +#endif + +/* **************************** + * INLINE mode + ******************************/ +/*! + * @defgroup public Public API + * Contains details on the public xxHash functions. + * @{ + */ +#ifdef XXH_DOXYGEN +/*! + * @brief Gives access to internal state declaration, required for static allocation. + * + * Incompatible with dynamic linking, due to risks of ABI changes. + * + * Usage: + * @code{.c} + * #define XXH_STATIC_LINKING_ONLY + * #include "xxhash.h" + * @endcode + */ +# define XXH_STATIC_LINKING_ONLY +/* Do not undef XXH_STATIC_LINKING_ONLY for Doxygen */ + +/*! + * @brief Gives access to internal definitions. + * + * Usage: + * @code{.c} + * #define XXH_STATIC_LINKING_ONLY + * #define XXH_IMPLEMENTATION + * #include "xxhash.h" + * @endcode + */ +# define XXH_IMPLEMENTATION +/* Do not undef XXH_IMPLEMENTATION for Doxygen */ + +/*! + * @brief Exposes the implementation and marks all functions as `inline`. + * + * Use these build macros to inline xxhash into the target unit. + * Inlining improves performance on small inputs, especially when the length is + * expressed as a compile-time constant: + * + * https://fastcompression.blogspot.com/2018/03/xxhash-for-small-keys-impressive-power.html + * + * It also keeps xxHash symbols private to the unit, so they are not exported. + * + * Usage: + * @code{.c} + * #define XXH_INLINE_ALL + * #include "xxhash.h" + * @endcode + * Do not compile and link xxhash.o as a separate object, as it is not useful. + */ +# define XXH_INLINE_ALL +# undef XXH_INLINE_ALL +/*! + * @brief Exposes the implementation without marking functions as inline. + */ +# define XXH_PRIVATE_API +# undef XXH_PRIVATE_API +/*! + * @brief Emulate a namespace by transparently prefixing all symbols. + * + * If you want to include _and expose_ xxHash functions from within your own + * library, but also want to avoid symbol collisions with other libraries which + * may also include xxHash, you can use @ref XXH_NAMESPACE to automatically prefix + * any public symbol from xxhash library with the value of @ref XXH_NAMESPACE + * (therefore, avoid empty or numeric values). + * + * Note that no change is required within the calling program as long as it + * includes `xxhash.h`: Regular symbol names will be automatically translated + * by this header. 
+ */ +# define XXH_NAMESPACE /* YOUR NAME HERE */ +# undef XXH_NAMESPACE +#endif + +#if (defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)) \ + && !defined(XXH_INLINE_ALL_31684351384) + /* this section should be traversed only once */ +# define XXH_INLINE_ALL_31684351384 + /* give access to the advanced API, required to compile implementations */ +# undef XXH_STATIC_LINKING_ONLY /* avoid macro redef */ +# define XXH_STATIC_LINKING_ONLY + /* make all functions private */ +# undef XXH_PUBLIC_API +# if defined(__GNUC__) +# define XXH_PUBLIC_API static __inline __attribute__((__unused__)) +# elif defined (__cplusplus) || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) +# define XXH_PUBLIC_API static inline +# elif defined(_MSC_VER) +# define XXH_PUBLIC_API static __inline +# else + /* note: this version may generate warnings for unused static functions */ +# define XXH_PUBLIC_API static +# endif + + /* + * This part deals with the special case where a unit wants to inline xxHash, + * but "xxhash.h" has previously been included without XXH_INLINE_ALL, + * such as part of some previously included *.h header file. + * Without further action, the new include would just be ignored, + * and functions would effectively _not_ be inlined (silent failure). + * The following macros solve this situation by prefixing all inlined names, + * avoiding naming collision with previous inclusions. + */ + /* Before that, we unconditionally #undef all symbols, + * in case they were already defined with XXH_NAMESPACE. + * They will then be redefined for XXH_INLINE_ALL + */ +# undef XXH_versionNumber + /* XXH32 */ +# undef XXH32 +# undef XXH32_createState +# undef XXH32_freeState +# undef XXH32_reset +# undef XXH32_update +# undef XXH32_digest +# undef XXH32_copyState +# undef XXH32_canonicalFromHash +# undef XXH32_hashFromCanonical + /* XXH64 */ +# undef XXH64 +# undef XXH64_createState +# undef XXH64_freeState +# undef XXH64_reset +# undef XXH64_update +# undef XXH64_digest +# undef XXH64_copyState +# undef XXH64_canonicalFromHash +# undef XXH64_hashFromCanonical + /* XXH3_64bits */ +# undef XXH3_64bits +# undef XXH3_64bits_withSecret +# undef XXH3_64bits_withSeed +# undef XXH3_64bits_withSecretandSeed +# undef XXH3_createState +# undef XXH3_freeState +# undef XXH3_copyState +# undef XXH3_64bits_reset +# undef XXH3_64bits_reset_withSeed +# undef XXH3_64bits_reset_withSecret +# undef XXH3_64bits_update +# undef XXH3_64bits_digest +# undef XXH3_generateSecret + /* XXH3_128bits */ +# undef XXH128 +# undef XXH3_128bits +# undef XXH3_128bits_withSeed +# undef XXH3_128bits_withSecret +# undef XXH3_128bits_reset +# undef XXH3_128bits_reset_withSeed +# undef XXH3_128bits_reset_withSecret +# undef XXH3_128bits_reset_withSecretandSeed +# undef XXH3_128bits_update +# undef XXH3_128bits_digest +# undef XXH128_isEqual +# undef XXH128_cmp +# undef XXH128_canonicalFromHash +# undef XXH128_hashFromCanonical + /* Finally, free the namespace itself */ +# undef XXH_NAMESPACE + + /* employ the namespace for XXH_INLINE_ALL */ +# define XXH_NAMESPACE XXH_INLINE_ + /* + * Some identifiers (enums, type names) are not symbols, + * but they must nonetheless be renamed to avoid redeclaration. + * Alternative solution: do not redeclare them. + * However, this requires some #ifdefs, and has a more dispersed impact. + * Meanwhile, renaming can be achieved in a single place. 
+ */ +# define XXH_IPREF(Id) XXH_NAMESPACE ## Id +# define XXH_OK XXH_IPREF(XXH_OK) +# define XXH_ERROR XXH_IPREF(XXH_ERROR) +# define XXH_errorcode XXH_IPREF(XXH_errorcode) +# define XXH32_canonical_t XXH_IPREF(XXH32_canonical_t) +# define XXH64_canonical_t XXH_IPREF(XXH64_canonical_t) +# define XXH128_canonical_t XXH_IPREF(XXH128_canonical_t) +# define XXH32_state_s XXH_IPREF(XXH32_state_s) +# define XXH32_state_t XXH_IPREF(XXH32_state_t) +# define XXH64_state_s XXH_IPREF(XXH64_state_s) +# define XXH64_state_t XXH_IPREF(XXH64_state_t) +# define XXH3_state_s XXH_IPREF(XXH3_state_s) +# define XXH3_state_t XXH_IPREF(XXH3_state_t) +# define XXH128_hash_t XXH_IPREF(XXH128_hash_t) + /* Ensure the header is parsed again, even if it was previously included */ +# undef XXHASH_H_5627135585666179 +# undef XXHASH_H_STATIC_13879238742 +#endif /* XXH_INLINE_ALL || XXH_PRIVATE_API */ + +/* **************************************************************** + * Stable API + *****************************************************************/ +#ifndef XXHASH_H_5627135585666179 +#define XXHASH_H_5627135585666179 1 + +/*! @brief Marks a global symbol. */ +#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API) +# if defined(_WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT)) +# ifdef XXH_EXPORT +# define XXH_PUBLIC_API __declspec(dllexport) +# elif XXH_IMPORT +# define XXH_PUBLIC_API __declspec(dllimport) +# endif +# else +# define XXH_PUBLIC_API /* do nothing */ +# endif +#endif + +#ifdef XXH_NAMESPACE +# define XXH_CAT(A,B) A##B +# define XXH_NAME2(A,B) XXH_CAT(A,B) +# define XXH_versionNumber XXH_NAME2(XXH_NAMESPACE, XXH_versionNumber) +/* XXH32 */ +# define XXH32 XXH_NAME2(XXH_NAMESPACE, XXH32) +# define XXH32_createState XXH_NAME2(XXH_NAMESPACE, XXH32_createState) +# define XXH32_freeState XXH_NAME2(XXH_NAMESPACE, XXH32_freeState) +# define XXH32_reset XXH_NAME2(XXH_NAMESPACE, XXH32_reset) +# define XXH32_update XXH_NAME2(XXH_NAMESPACE, XXH32_update) +# define XXH32_digest XXH_NAME2(XXH_NAMESPACE, XXH32_digest) +# define XXH32_copyState XXH_NAME2(XXH_NAMESPACE, XXH32_copyState) +# define XXH32_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH32_canonicalFromHash) +# define XXH32_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH32_hashFromCanonical) +/* XXH64 */ +# define XXH64 XXH_NAME2(XXH_NAMESPACE, XXH64) +# define XXH64_createState XXH_NAME2(XXH_NAMESPACE, XXH64_createState) +# define XXH64_freeState XXH_NAME2(XXH_NAMESPACE, XXH64_freeState) +# define XXH64_reset XXH_NAME2(XXH_NAMESPACE, XXH64_reset) +# define XXH64_update XXH_NAME2(XXH_NAMESPACE, XXH64_update) +# define XXH64_digest XXH_NAME2(XXH_NAMESPACE, XXH64_digest) +# define XXH64_copyState XXH_NAME2(XXH_NAMESPACE, XXH64_copyState) +# define XXH64_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH64_canonicalFromHash) +# define XXH64_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH64_hashFromCanonical) +/* XXH3_64bits */ +# define XXH3_64bits XXH_NAME2(XXH_NAMESPACE, XXH3_64bits) +# define XXH3_64bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecret) +# define XXH3_64bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSeed) +# define XXH3_64bits_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_withSecretandSeed) +# define XXH3_createState XXH_NAME2(XXH_NAMESPACE, XXH3_createState) +# define XXH3_freeState XXH_NAME2(XXH_NAMESPACE, XXH3_freeState) +# define XXH3_copyState XXH_NAME2(XXH_NAMESPACE, XXH3_copyState) +# define XXH3_64bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset) +# define 
XXH3_64bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSeed) +# define XXH3_64bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecret) +# define XXH3_64bits_reset_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_reset_withSecretandSeed) +# define XXH3_64bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_update) +# define XXH3_64bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_64bits_digest) +# define XXH3_generateSecret XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret) +# define XXH3_generateSecret_fromSeed XXH_NAME2(XXH_NAMESPACE, XXH3_generateSecret_fromSeed) +/* XXH3_128bits */ +# define XXH128 XXH_NAME2(XXH_NAMESPACE, XXH128) +# define XXH3_128bits XXH_NAME2(XXH_NAMESPACE, XXH3_128bits) +# define XXH3_128bits_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSeed) +# define XXH3_128bits_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecret) +# define XXH3_128bits_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_withSecretandSeed) +# define XXH3_128bits_reset XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset) +# define XXH3_128bits_reset_withSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSeed) +# define XXH3_128bits_reset_withSecret XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecret) +# define XXH3_128bits_reset_withSecretandSeed XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_reset_withSecretandSeed) +# define XXH3_128bits_update XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_update) +# define XXH3_128bits_digest XXH_NAME2(XXH_NAMESPACE, XXH3_128bits_digest) +# define XXH128_isEqual XXH_NAME2(XXH_NAMESPACE, XXH128_isEqual) +# define XXH128_cmp XXH_NAME2(XXH_NAMESPACE, XXH128_cmp) +# define XXH128_canonicalFromHash XXH_NAME2(XXH_NAMESPACE, XXH128_canonicalFromHash) +# define XXH128_hashFromCanonical XXH_NAME2(XXH_NAMESPACE, XXH128_hashFromCanonical) +#endif + + +/* ************************************* +* Compiler specifics +***************************************/ + +/* specific declaration modes for Windows */ +#if !defined(XXH_INLINE_ALL) && !defined(XXH_PRIVATE_API) +# if defined(_WIN32) && defined(_MSC_VER) && (defined(XXH_IMPORT) || defined(XXH_EXPORT)) +# ifdef XXH_EXPORT +# define XXH_PUBLIC_API __declspec(dllexport) +# elif XXH_IMPORT +# define XXH_PUBLIC_API __declspec(dllimport) +# endif +# else +# define XXH_PUBLIC_API /* do nothing */ +# endif +#endif + +#if defined (__GNUC__) +# define XXH_CONSTF __attribute__((__const__)) +# define XXH_PUREF __attribute__((__pure__)) +# define XXH_MALLOCF __attribute__((__malloc__)) +#else +# define XXH_CONSTF /* disable */ +# define XXH_PUREF +# define XXH_MALLOCF +#endif + +/* ************************************* +* Version +***************************************/ +#define XXH_VERSION_MAJOR 0 +#define XXH_VERSION_MINOR 8 +#define XXH_VERSION_RELEASE 3 +/*! @brief Version number, encoded as two digits each */ +#define XXH_VERSION_NUMBER (XXH_VERSION_MAJOR *100*100 + XXH_VERSION_MINOR *100 + XXH_VERSION_RELEASE) + +/*! + * @brief Obtains the xxHash version. + * + * This is mostly useful when xxHash is compiled as a shared library, + * since the returned value comes from the library, as opposed to header file. + * + * @return @ref XXH_VERSION_NUMBER of the invoked library. + */ +XXH_PUBLIC_API XXH_CONSTF unsigned XXH_versionNumber (void); + + +/* **************************** +* Common basic types +******************************/ +#include <stddef.h> /* size_t */ +/*! + * @brief Exit code for the streaming API. 
+ */ +typedef enum { + XXH_OK = 0, /*!< OK */ + XXH_ERROR /*!< Error */ +} XXH_errorcode; + + +/*-********************************************************************** +* 32-bit hash +************************************************************************/ +#if defined(XXH_DOXYGEN) /* Don't show <stdint.h> include */ +/*! + * @brief An unsigned 32-bit integer. + * + * Not necessarily defined to `uint32_t` but functionally equivalent. + */ +typedef uint32_t XXH32_hash_t; + +#elif !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include <inttypes.h> +# else +# include <stdint.h> +# endif + typedef uint32_t XXH32_hash_t; + +#else +# include <limits.h> +# if UINT_MAX == 0xFFFFFFFFUL + typedef unsigned int XXH32_hash_t; +# elif ULONG_MAX == 0xFFFFFFFFUL + typedef unsigned long XXH32_hash_t; +# else +# error "unsupported platform: need a 32-bit type" +# endif +#endif + +/*! + * @} + * + * @defgroup XXH32_family XXH32 family + * @ingroup public + * Contains functions used in the classic 32-bit xxHash algorithm. + * + * @note + * XXH32 is useful for older platforms, with no or poor 64-bit performance. + * Note that the @ref XXH3_family provides competitive speed for both 32-bit + * and 64-bit systems, and offers true 64/128 bit hash results. + * + * @see @ref XXH64_family, @ref XXH3_family : Other xxHash families + * @see @ref XXH32_impl for implementation details + * @{ + */ + +/*! + * @brief Calculates the 32-bit hash of @p input using xxHash32. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * @param seed The 32-bit seed to alter the hash's output predictably. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 32-bit xxHash32 value. + * + * @see @ref single_shot_example "Single Shot Example" for an example. + */ +XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32 (const void* input, size_t length, XXH32_hash_t seed); + +#ifndef XXH_NO_STREAM +/*! + * @typedef struct XXH32_state_s XXH32_state_t + * @brief The opaque state struct for the XXH32 streaming API. + * + * @see XXH32_state_s for details. + * @see @ref streaming_example "Streaming Example" + */ +typedef struct XXH32_state_s XXH32_state_t; + +/*! + * @brief Allocates an @ref XXH32_state_t. + * + * @return An allocated pointer of @ref XXH32_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH32_freeState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_MALLOCF XXH32_state_t* XXH32_createState(void); +/*! + * @brief Frees an @ref XXH32_state_t. + * + * @param statePtr A pointer to an @ref XXH32_state_t allocated with @ref XXH32_createState(). + * + * @return @ref XXH_OK. + * + * @note @p statePtr must be allocated with XXH32_createState(). + * + * @see @ref streaming_example "Streaming Example" + * + */ +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr); +/*! + * @brief Copies one @ref XXH32_state_t to another. + * + * @param dst_state The state to copy to. + * @param src_state The state to copy from. + * @pre + * @p dst_state and @p src_state must not be `NULL` and must not overlap. 
+ */
+XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dst_state, const XXH32_state_t* src_state);
+
+/*!
+ * @brief Resets an @ref XXH32_state_t to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ * @param seed The 32-bit seed to alter the hash result predictably.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note This function resets and seeds a state. Call it before @ref XXH32_update().
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH32_reset (XXH32_state_t* statePtr, XXH32_hash_t seed);
+
+/*!
+ * @brief Consumes a block of @p input to an @ref XXH32_state_t.
+ *
+ * @param statePtr The state struct to update.
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note Call this to incrementally consume blocks of data.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH32_update (XXH32_state_t* statePtr, const void* input, size_t length);
+
+/*!
+ * @brief Returns the calculated hash value from an @ref XXH32_state_t.
+ *
+ * @param statePtr The state struct to calculate the hash from.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return The calculated 32-bit xxHash32 value from that state.
+ *
+ * @note
+ * Calling XXH32_digest() will not affect @p statePtr, so you can update,
+ * digest, and update again.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32_digest (const XXH32_state_t* statePtr);
+#endif /* !XXH_NO_STREAM */
+
+/******* Canonical representation *******/
+
+/*!
+ * @brief Canonical (big endian) representation of @ref XXH32_hash_t.
+ */
+typedef struct {
+ unsigned char digest[4]; /*!< Hash bytes, big endian */
+} XXH32_canonical_t;
+
+/*!
+ * @brief Converts an @ref XXH32_hash_t to a big endian @ref XXH32_canonical_t.
+ *
+ * @param dst The @ref XXH32_canonical_t pointer to be stored to.
+ * @param hash The @ref XXH32_hash_t to be converted.
+ *
+ * @pre
+ * @p dst must not be `NULL`.
+ *
+ * @see @ref canonical_representation_example "Canonical Representation Example"
+ */
+XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash);
+
+/*!
+ * @brief Converts an @ref XXH32_canonical_t to a native @ref XXH32_hash_t.
+ *
+ * @param src The @ref XXH32_canonical_t to convert.
+ *
+ * @pre
+ * @p src must not be `NULL`.
+ *
+ * @return The converted hash.
+ *
+ * @see @ref canonical_representation_example "Canonical Representation Example"
+ */
+XXH_PUBLIC_API XXH_PUREF XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src);
+
+
+/*! @cond Doxygen ignores this part */
+#ifdef __has_attribute
+# define XXH_HAS_ATTRIBUTE(x) __has_attribute(x)
+#else
+# define XXH_HAS_ATTRIBUTE(x) 0
+#endif
+/*! @endcond */
+
+/*! @cond Doxygen ignores this part */
+/*
+ * C23 __STDC_VERSION__ number hasn't been specified yet. For now
+ * leave as `201711L` (C17 + 1).
+ * TODO: Update to correct value when it's been specified.
+ */
+#define XXH_C23_VN 201711L
+/*!
@endcond */ + +/*! @cond Doxygen ignores this part */ +/* C-language Attributes are added in C23. */ +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= XXH_C23_VN) && defined(__has_c_attribute) +# define XXH_HAS_C_ATTRIBUTE(x) __has_c_attribute(x) +#else +# define XXH_HAS_C_ATTRIBUTE(x) 0 +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +#if defined(__cplusplus) && defined(__has_cpp_attribute) +# define XXH_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define XXH_HAS_CPP_ATTRIBUTE(x) 0 +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* + * Define XXH_FALLTHROUGH macro for annotating switch case with the 'fallthrough' attribute + * introduced in CPP17 and C23. + * CPP17 : https://en.cppreference.com/w/cpp/language/attributes/fallthrough + * C23 : https://en.cppreference.com/w/c/language/attributes/fallthrough + */ +#if XXH_HAS_C_ATTRIBUTE(fallthrough) || XXH_HAS_CPP_ATTRIBUTE(fallthrough) +# define XXH_FALLTHROUGH [[fallthrough]] +#elif XXH_HAS_ATTRIBUTE(__fallthrough__) +# define XXH_FALLTHROUGH __attribute__ ((__fallthrough__)) +#else +# define XXH_FALLTHROUGH /* fallthrough */ +#endif +/*! @endcond */ + +/*! @cond Doxygen ignores this part */ +/* + * Define XXH_NOESCAPE for annotated pointers in public API. + * https://clang.llvm.org/docs/AttributeReference.html#noescape + * As of writing this, only supported by clang. + */ +#if XXH_HAS_ATTRIBUTE(noescape) +# define XXH_NOESCAPE __attribute__((__noescape__)) +#else +# define XXH_NOESCAPE +#endif +/*! @endcond */ + + +/*! + * @} + * @ingroup public + * @{ + */ + +#ifndef XXH_NO_LONG_LONG +/*-********************************************************************** +* 64-bit hash +************************************************************************/ +#if defined(XXH_DOXYGEN) /* don't include <stdint.h> */ +/*! + * @brief An unsigned 64-bit integer. + * + * Not necessarily defined to `uint64_t` but functionally equivalent. + */ +typedef uint64_t XXH64_hash_t; +#elif !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include <inttypes.h> +# else +# include <stdint.h> +# endif + typedef uint64_t XXH64_hash_t; +#else +# include <limits.h> +# if defined(__LP64__) && ULONG_MAX == 0xFFFFFFFFFFFFFFFFULL + /* LP64 ABI says uint64_t is unsigned long */ + typedef unsigned long XXH64_hash_t; +# else + /* the following type must have a width of 64-bit */ + typedef unsigned long long XXH64_hash_t; +# endif +#endif + +/*! + * @} + * + * @defgroup XXH64_family XXH64 family + * @ingroup public + * @{ + * Contains functions used in the classic 64-bit xxHash algorithm. + * + * @note + * XXH3 provides competitive speed for both 32-bit and 64-bit systems, + * and offers true 64/128 bit hash results. + * It provides better speed for systems with vector processing capabilities. + */ + +/*! + * @brief Calculates the 64-bit hash of @p input using xxHash64. + * + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * @param seed The 64-bit seed to alter the hash's output predictably. + * + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return The calculated 64-bit xxHash64 value. + * + * @see @ref single_shot_example "Single Shot Example" for an example. 
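+ *
+ * A one-shot sketch (hypothetical message):
+ * @code{.c}
+ * const char msg[] = "hello";
+ * XXH64_hash_t h = XXH64(msg, sizeof(msg) - 1, 0);  // seed = 0
+ * @endcode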
+ */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed); + +/******* Streaming *******/ +#ifndef XXH_NO_STREAM +/*! + * @brief The opaque state struct for the XXH64 streaming API. + * + * @see XXH64_state_s for details. + * @see @ref streaming_example "Streaming Example" + */ +typedef struct XXH64_state_s XXH64_state_t; /* incomplete type */ + +/*! + * @brief Allocates an @ref XXH64_state_t. + * + * @return An allocated pointer of @ref XXH64_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH64_freeState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_MALLOCF XXH64_state_t* XXH64_createState(void); + +/*! + * @brief Frees an @ref XXH64_state_t. + * + * @param statePtr A pointer to an @ref XXH64_state_t allocated with @ref XXH64_createState(). + * + * @return @ref XXH_OK. + * + * @note @p statePtr must be allocated with XXH64_createState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr); + +/*! + * @brief Copies one @ref XXH64_state_t to another. + * + * @param dst_state The state to copy to. + * @param src_state The state to copy from. + * @pre + * @p dst_state and @p src_state must not be `NULL` and must not overlap. + */ +XXH_PUBLIC_API void XXH64_copyState(XXH_NOESCAPE XXH64_state_t* dst_state, const XXH64_state_t* src_state); + +/*! + * @brief Resets an @ref XXH64_state_t to begin a new hash. + * + * @param statePtr The state struct to reset. + * @param seed The 64-bit seed to alter the hash result predictably. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note This function resets and seeds a state. Call it before @ref XXH64_update(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_reset (XXH_NOESCAPE XXH64_state_t* statePtr, XXH64_hash_t seed); + +/*! + * @brief Consumes a block of @p input to an @ref XXH64_state_t. + * + * @param statePtr The state struct to update. + * @param input The block of data to be hashed, at least @p length bytes in size. + * @param length The length of @p input, in bytes. + * + * @pre + * @p statePtr must not be `NULL`. + * @pre + * The memory between @p input and @p input + @p length must be valid, + * readable, contiguous memory. However, if @p length is `0`, @p input may be + * `NULL`. In C++, this also must be *TriviallyCopyable*. + * + * @return @ref XXH_OK on success. + * @return @ref XXH_ERROR on failure. + * + * @note Call this to incrementally consume blocks of data. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH64_update (XXH_NOESCAPE XXH64_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length); + +/*! + * @brief Returns the calculated hash value from an @ref XXH64_state_t. + * + * @param statePtr The state struct to calculate the hash from. + * + * @pre + * @p statePtr must not be `NULL`. + * + * @return The calculated 64-bit xxHash64 value from that state. + * + * @note + * Calling XXH64_digest() will not affect @p statePtr, so you can update, + * digest, and update again. + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64_digest (XXH_NOESCAPE const XXH64_state_t* statePtr); +#endif /* !XXH_NO_STREAM */ +/******* Canonical representation *******/ + +/*! 
+ * @brief Canonical (big endian) representation of @ref XXH64_hash_t. + */ +typedef struct { unsigned char digest[sizeof(XXH64_hash_t)]; } XXH64_canonical_t; + +/*! + * @brief Converts an @ref XXH64_hash_t to a big endian @ref XXH64_canonical_t. + * + * @param dst The @ref XXH64_canonical_t pointer to be stored to. + * @param hash The @ref XXH64_hash_t to be converted. + * + * @pre + * @p dst must not be `NULL`. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH_NOESCAPE XXH64_canonical_t* dst, XXH64_hash_t hash); + +/*! + * @brief Converts an @ref XXH64_canonical_t to a native @ref XXH64_hash_t. + * + * @param src The @ref XXH64_canonical_t to convert. + * + * @pre + * @p src must not be `NULL`. + * + * @return The converted hash. + * + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH64_hashFromCanonical(XXH_NOESCAPE const XXH64_canonical_t* src); + +#ifndef XXH_NO_XXH3 + +/*! + * @} + * ************************************************************************ + * @defgroup XXH3_family XXH3 family + * @ingroup public + * @{ + * + * XXH3 is a more recent hash algorithm featuring: + * - Improved speed for both small and large inputs + * - True 64-bit and 128-bit outputs + * - SIMD acceleration + * - Improved 32-bit viability + * + * Speed analysis methodology is explained here: + * + * https://fastcompression.blogspot.com/2019/03/presenting-xxh3.html + * + * Compared to XXH64, expect XXH3 to run approximately + * ~2x faster on large inputs and >3x faster on small ones, + * exact differences vary depending on platform. + * + * XXH3's speed benefits greatly from SIMD and 64-bit arithmetic, + * but does not require it. + * Most 32-bit and 64-bit targets that can run XXH32 smoothly can run XXH3 + * at competitive speeds, even without vector support. Further details are + * explained in the implementation. + * + * XXH3 has a fast scalar implementation, but it also includes accelerated SIMD + * implementations for many common platforms: + * - AVX512 + * - AVX2 + * - SSE2 + * - ARM NEON + * - WebAssembly SIMD128 + * - POWER8 VSX + * - s390x ZVector + * This can be controlled via the @ref XXH_VECTOR macro, but it automatically + * selects the best version according to predefined macros. For the x86 family, an + * automatic runtime dispatcher is included separately in @ref xxh_x86dispatch.c. + * + * XXH3 implementation is portable: + * it has a generic C90 formulation that can be compiled on any platform, + * all implementations generate exactly the same hash value on all platforms. + * Starting from v0.8.0, it's also labelled "stable", meaning that + * any future version will also generate the same hash value. + * + * XXH3 offers 2 variants, _64bits and _128bits. + * + * When only 64 bits are needed, prefer invoking the _64bits variant, as it + * reduces the amount of mixing, resulting in faster speed on small inputs. + * It's also generally simpler to manipulate a scalar return type than a struct. + * + * The API supports one-shot hashing, streaming mode, and custom secrets. + */ +/*-********************************************************************** +* XXH3 64-bit variant +************************************************************************/ + +/*! + * @brief Calculates 64-bit unseeded variant of XXH3 hash of @p input. + * + * @param input The block of data to be hashed, at least @p length bytes in size. 
+ * @param length The length of @p input, in bytes.
+ *
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return The calculated 64-bit XXH3 hash value.
+ *
+ * @note
+ * This is equivalent to @ref XXH3_64bits_withSeed() with a seed of `0`, however
+ * it may have slightly better performance due to constant propagation of the
+ * defaults.
+ *
+ * @see
+ * XXH3_64bits_withSeed(), XXH3_64bits_withSecret(): other seeding variants
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits(XXH_NOESCAPE const void* input, size_t length);
+
+/*!
+ * @brief Calculates 64-bit seeded variant of XXH3 hash of @p input.
+ *
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ * @param seed The 64-bit seed to alter the hash result predictably.
+ *
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return The calculated 64-bit XXH3 hash value.
+ *
+ * @note
+ * seed == 0 produces the same results as @ref XXH3_64bits().
+ *
+ * This variant generates a custom secret on the fly based on the default secret
+ * altered using the @p seed value.
+ *
+ * While this operation is decently fast, note that it's not completely free.
+ *
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_withSeed(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed);
+
+/*!
+ * The bare minimum size for a custom secret.
+ *
+ * @see
+ * XXH3_64bits_withSecret(), XXH3_64bits_reset_withSecret(),
+ * XXH3_128bits_withSecret(), XXH3_128bits_reset_withSecret().
+ */
+#define XXH3_SECRET_SIZE_MIN 136
+
+/*!
+ * @brief Calculates 64-bit variant of XXH3 with a custom "secret".
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ *
+ * @return The calculated 64-bit XXH3 hash value.
+ *
+ * @pre
+ * The memory between @p data and @p data + @p len must be valid,
+ * readable, contiguous memory. However, if @p len is `0`, @p data may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * It's possible to provide any blob of bytes as a "secret" to generate the hash.
+ * This makes it more difficult for an external actor to prepare an intentional collision.
+ * The main condition is that @p secretSize *must* be large enough (>= @ref XXH3_SECRET_SIZE_MIN).
+ * However, the quality of the secret impacts the dispersion of the hash algorithm.
+ * Therefore, the secret _must_ look like a bunch of random bytes.
+ * Avoid "trivial" or structured data such as repeated sequences or a text document.
+ * Whenever in doubt about the "randomness" of the blob of bytes,
+ * consider employing @ref XXH3_generateSecret() instead (see below).
+ * It will generate a proper high entropy secret derived from the blob of bytes.
+ * Another advantage of using XXH3_generateSecret() is that
+ * it guarantees that all bits within the initial blob of bytes
+ * will impact every bit of the output.
+ * This is not necessarily the case when using the blob of bytes directly
+ * because, when hashing _small_ inputs, only a portion of the secret is employed.
+ *
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_withSecret(XXH_NOESCAPE const void* data, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize);
+
+
+/******* Streaming *******/
+#ifndef XXH_NO_STREAM
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ */
+
+/*!
+ * @brief The opaque state struct for the XXH3 streaming API.
+ *
+ * @see XXH3_state_s for details.
+ * @see @ref streaming_example "Streaming Example"
+ */
+typedef struct XXH3_state_s XXH3_state_t;
+XXH_PUBLIC_API XXH_MALLOCF XXH3_state_t* XXH3_createState(void);
+XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr);
+
+/*!
+ * @brief Copies one @ref XXH3_state_t to another.
+ *
+ * @param dst_state The state to copy to.
+ * @param src_state The state to copy from.
+ * @pre
+ * @p dst_state and @p src_state must not be `NULL` and must not overlap.
+ */
+XXH_PUBLIC_API void XXH3_copyState(XXH_NOESCAPE XXH3_state_t* dst_state, XXH_NOESCAPE const XXH3_state_t* src_state);
+
+/*!
+ * @brief Resets an @ref XXH3_state_t to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * - This function resets `statePtr` and generates a secret with default parameters.
+ * - Call this function before @ref XXH3_64bits_update().
+ * - Digest will be equivalent to `XXH3_64bits()`.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ *
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr);
+
+/*!
+ * @brief Resets an @ref XXH3_state_t with 64-bit seed to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ * @param seed The 64-bit seed to alter the hash result predictably.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * - This function resets `statePtr` and generates a secret from `seed`.
+ * - Call this function before @ref XXH3_64bits_update().
+ * - Digest will be equivalent to `XXH3_64bits_withSeed()`.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ *
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed);
+
+/*!
+ * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * `secret` is referenced; it _must outlive_ the hash streaming session.
+ *
+ * Similar to the one-shot API, `secretSize` must be >= @ref XXH3_SECRET_SIZE_MIN,
+ * and the quality of produced hash values depends on secret's entropy
+ * (secret's content should look like a bunch of random bytes).
+ * When in doubt about the randomness of a candidate `secret`,
+ * consider employing `XXH3_generateSecret()` instead (see below).
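+ *
+ * A streaming sketch with a derived secret (hypothetical seed material;
+ * XXH3_generateSecret() belongs to the unstable API and needs
+ * XXH_STATIC_LINKING_ONLY):
+ * @code{.c}
+ * unsigned char secret[XXH3_SECRET_SIZE_MIN];
+ * XXH3_generateSecret(secret, sizeof(secret), "seed material", 13);
+ * XXH3_state_t* st = XXH3_createState();
+ * XXH3_64bits_reset_withSecret(st, secret, sizeof(secret)); // secret must outlive the session
+ * XXH3_64bits_update(st, "data", 4);
+ * XXH64_hash_t h = XXH3_64bits_digest(st);
+ * XXH3_freeState(st);
+ * @endcode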
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize);
+
+/*!
+ * @brief Consumes a block of @p input to an @ref XXH3_state_t.
+ *
+ * @param statePtr The state struct to update.
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ * @pre
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note Call this to incrementally consume blocks of data.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_64bits_update (XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length);
+
+/*!
+ * @brief Returns the calculated XXH3 64-bit hash value from an @ref XXH3_state_t.
+ *
+ * @param statePtr The state struct to calculate the hash from.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return The calculated XXH3 64-bit hash value from that state.
+ *
+ * @note
+ * Calling XXH3_64bits_digest() will not affect @p statePtr, so you can update,
+ * digest, and update again.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_PUREF XXH64_hash_t XXH3_64bits_digest (XXH_NOESCAPE const XXH3_state_t* statePtr);
+#endif /* !XXH_NO_STREAM */
+
+/* note : canonical representation of XXH3 is the same as XXH64
+ * since they both produce XXH64_hash_t values */
+
+
+/*-**********************************************************************
+* XXH3 128-bit variant
+************************************************************************/
+
+/*!
+ * @brief The return value from 128-bit hashes.
+ *
+ * Stored in little endian order, although the fields themselves are in native
+ * endianness.
+ */
+typedef struct {
+ XXH64_hash_t low64; /*!< `value & 0xFFFFFFFFFFFFFFFF` */
+ XXH64_hash_t high64; /*!< `value >> 64` */
+} XXH128_hash_t;
+
+/*!
+ * @brief Calculates 128-bit unseeded variant of XXH3 of @p data.
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ *
+ * @return The calculated 128-bit variant of XXH3 value.
+ *
+ * The 128-bit variant of XXH3 has more strength, but it has a bit of overhead
+ * for shorter inputs.
+ *
+ * This is equivalent to @ref XXH3_128bits_withSeed() with a seed of `0`, however
+ * it may have slightly better performance due to constant propagation of the
+ * defaults.
+ *
+ * @see XXH3_128bits_withSeed(), XXH3_128bits_withSecret(): other seeding variants
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits(XXH_NOESCAPE const void* data, size_t len);
+/*! @brief Calculates 128-bit seeded variant of XXH3 hash of @p data.
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ * @param seed The 64-bit seed to alter the hash result predictably.
+ *
+ * @return The calculated 128-bit variant of XXH3 value.
+ *
+ * @note
+ * seed == 0 produces the same results as @ref XXH3_128bits().
+ *
+ * This variant generates a custom secret on the fly based on the default secret
+ * altered using the @p seed value.
+ *
+ * While this operation is decently fast, note that it's not completely free.
+ *
+ * @see XXH3_128bits(), XXH3_128bits_withSecret(): other seeding variants
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_withSeed(XXH_NOESCAPE const void* data, size_t len, XXH64_hash_t seed);
+/*!
+ * @brief Calculates 128-bit variant of XXH3 with a custom "secret".
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ *
+ * @return The calculated 128-bit variant of XXH3 value.
+ *
+ * It's possible to provide any blob of bytes as a "secret" to generate the hash.
+ * This makes it more difficult for an external actor to prepare an intentional collision.
+ * The main condition is that @p secretSize *must* be large enough (>= @ref XXH3_SECRET_SIZE_MIN).
+ * However, the quality of the secret impacts the dispersion of the hash algorithm.
+ * Therefore, the secret _must_ look like a bunch of random bytes.
+ * Avoid "trivial" or structured data such as repeated sequences or a text document.
+ * Whenever in doubt about the "randomness" of the blob of bytes,
+ * consider employing @ref XXH3_generateSecret() instead (see below).
+ * It will generate a proper high entropy secret derived from the blob of bytes.
+ * Another advantage of using XXH3_generateSecret() is that
+ * it guarantees that all bits within the initial blob of bytes
+ * will impact every bit of the output.
+ * This is not necessarily the case when using the blob of bytes directly
+ * because, when hashing _small_ inputs, only a portion of the secret is employed.
+ *
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_withSecret(XXH_NOESCAPE const void* data, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize);
+
+/******* Streaming *******/
+#ifndef XXH_NO_STREAM
+/*
+ * Streaming requires state maintenance.
+ * This operation costs memory and CPU.
+ * As a consequence, streaming is slower than one-shot hashing.
+ * For better performance, prefer one-shot functions whenever applicable.
+ *
+ * XXH3_128bits uses the same XXH3_state_t as XXH3_64bits().
+ * Use already declared XXH3_createState() and XXH3_freeState().
+ *
+ * All reset and streaming functions have the same meaning as their 64-bit counterparts.
+ */
+
+/*!
+ * @brief Resets an @ref XXH3_state_t to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * - This function resets `statePtr` and generates a secret with default parameters.
+ * - Call it before @ref XXH3_128bits_update().
+ * - Digest will be equivalent to `XXH3_128bits()`.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr);
+
+/*!
+ * @brief Resets an @ref XXH3_state_t with 64-bit seed to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ * @param seed The 64-bit seed to alter the hash result predictably.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * - This function resets `statePtr` and generates a secret from `seed`.
+ * - Call it before @ref XXH3_128bits_update().
+ * - Digest will be equivalent to `XXH3_128bits_withSeed()`.
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed);
+/*!
+ * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash.
+ *
+ * @param statePtr The state struct to reset.
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * `secret` is referenced; it _must outlive_ the hash streaming session.
+ * Similar to the one-shot API, `secretSize` must be >= @ref XXH3_SECRET_SIZE_MIN,
+ * and the quality of produced hash values depends on secret's entropy
+ * (secret's content should look like a bunch of random bytes).
+ * When in doubt about the randomness of a candidate `secret`,
+ * consider employing `XXH3_generateSecret()` instead (see below).
+ *
+ * @see @ref streaming_example "Streaming Example"
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize);
+
+/*!
+ * @brief Consumes a block of @p input to an @ref XXH3_state_t.
+ *
+ * Call this to incrementally consume blocks of data.
+ *
+ * @param statePtr The state struct to update.
+ * @param input The block of data to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @note
+ * The memory between @p input and @p input + @p length must be valid,
+ * readable, contiguous memory. However, if @p length is `0`, @p input may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_128bits_update (XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* input, size_t length);
+
+/*!
+ * @brief Returns the calculated XXH3 128-bit hash value from an @ref XXH3_state_t.
+ *
+ * @param statePtr The state struct to calculate the hash from.
+ *
+ * @pre
+ * @p statePtr must not be `NULL`.
+ *
+ * @return The calculated XXH3 128-bit hash value from that state.
+ *
+ * @note
+ * Calling XXH3_128bits_digest() will not affect @p statePtr, so you can update,
+ * digest, and update again.
+ *
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH3_128bits_digest (XXH_NOESCAPE const XXH3_state_t* statePtr);
+#endif /* !XXH_NO_STREAM */
+
+/* Following helper functions make it possible to compare XXH128_hash_t values.
+ * Since XXH128_hash_t is a structure, this capability is not offered by the language.
+ * Note: For better performance, these functions can be inlined using XXH_INLINE_ALL */
+
+/*!
+ * @brief Check equality of two XXH128_hash_t values
+ *
+ * @param h1 The 128-bit hash value.
+ * @param h2 Another 128-bit hash value.
+ *
+ * @return `1` if `h1` and `h2` are equal.
+ * @return `0` if they are not.
+ */
+XXH_PUBLIC_API XXH_PUREF int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2);
+
+/*!
+ * @brief Compares two @ref XXH128_hash_t
+ *
+ * This comparator is compatible with stdlib's `qsort()`/`bsearch()`.
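+ *
+ * A brief sorting sketch (hypothetical values; `qsort()` assumes `<stdlib.h>`):
+ * @code{.c}
+ * XXH128_hash_t hashes[4];
+ * for (int i = 0; i < 4; i++) hashes[i] = XXH3_128bits(&i, sizeof(i));
+ * qsort(hashes, 4, sizeof(hashes[0]), XXH128_cmp);  // ascending order
+ * @endcode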
+ * + * @param h128_1 Left-hand side value + * @param h128_2 Right-hand side value + * + * @return >0 if @p h128_1 > @p h128_2 + * @return =0 if @p h128_1 == @p h128_2 + * @return <0 if @p h128_1 < @p h128_2 + */ +XXH_PUBLIC_API XXH_PUREF int XXH128_cmp(XXH_NOESCAPE const void* h128_1, XXH_NOESCAPE const void* h128_2); + + +/******* Canonical representation *******/ +typedef struct { unsigned char digest[sizeof(XXH128_hash_t)]; } XXH128_canonical_t; + + +/*! + * @brief Converts an @ref XXH128_hash_t to a big endian @ref XXH128_canonical_t. + * + * @param dst The @ref XXH128_canonical_t pointer to be stored to. + * @param hash The @ref XXH128_hash_t to be converted. + * + * @pre + * @p dst must not be `NULL`. + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API void XXH128_canonicalFromHash(XXH_NOESCAPE XXH128_canonical_t* dst, XXH128_hash_t hash); + +/*! + * @brief Converts an @ref XXH128_canonical_t to a native @ref XXH128_hash_t. + * + * @param src The @ref XXH128_canonical_t to convert. + * + * @pre + * @p src must not be `NULL`. + * + * @return The converted hash. + * @see @ref canonical_representation_example "Canonical Representation Example" + */ +XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH128_hashFromCanonical(XXH_NOESCAPE const XXH128_canonical_t* src); + + +#endif /* !XXH_NO_XXH3 */ +#endif /* XXH_NO_LONG_LONG */ + +/*! + * @} + */ +#endif /* XXHASH_H_5627135585666179 */ + + + +#if defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) +#define XXHASH_H_STATIC_13879238742 +/* **************************************************************************** + * This section contains declarations which are not guaranteed to remain stable. + * They may change in future versions, becoming incompatible with a different + * version of the library. + * These declarations should only be used with static linking. + * Never use them in association with dynamic linking! + ***************************************************************************** */ + +/* + * These definitions are only present to allow static allocation + * of XXH states, on stack or in a struct, for example. + * Never **ever** access their members directly. + */ + +/*! + * @internal + * @brief Structure for XXH32 streaming API. + * + * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY, + * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is + * an opaque type. This allows fields to safely be changed. + * + * Typedef'd to @ref XXH32_state_t. + * Do not access the members of this struct directly. + * @see XXH64_state_s, XXH3_state_s + */ +struct XXH32_state_s { + XXH32_hash_t total_len_32; /*!< Total length hashed, modulo 2^32 */ + XXH32_hash_t large_len; /*!< Whether the hash is >= 16 (handles @ref total_len_32 overflow) */ + XXH32_hash_t v[4]; /*!< Accumulator lanes */ + XXH32_hash_t mem32[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[16]. */ + XXH32_hash_t memsize; /*!< Amount of data in @ref mem32 */ + XXH32_hash_t reserved; /*!< Reserved field. Do not read nor write to it. */ +}; /* typedef'd to XXH32_state_t */ + + +#ifndef XXH_NO_LONG_LONG /* defined when there is no 64-bit support */ + +/*! + * @internal + * @brief Structure for XXH64 streaming API. + * + * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY, + * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined. Otherwise it is + * an opaque type. This allows fields to safely be changed. + * + * Typedef'd to @ref XXH64_state_t. 
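+ *
+ * With XXH_STATIC_LINKING_ONLY defined before inclusion, the complete type
+ * allows stack placement (a minimal sketch):
+ * @code{.c}
+ * XXH64_state_t state;               // no heap allocation
+ * XXH64_reset(&state, 0);
+ * XXH64_update(&state, "abc", 3);
+ * XXH64_hash_t h = XXH64_digest(&state);
+ * @endcode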
+ * Do not access the members of this struct directly.
+ * @see XXH32_state_s, XXH3_state_s
+ */
+struct XXH64_state_s {
+ XXH64_hash_t total_len; /*!< Total length hashed. This is always 64-bit. */
+ XXH64_hash_t v[4]; /*!< Accumulator lanes */
+ XXH64_hash_t mem64[4]; /*!< Internal buffer for partial reads. Treated as unsigned char[32]. */
+ XXH32_hash_t memsize; /*!< Amount of data in @ref mem64 */
+ XXH32_hash_t reserved32; /*!< Reserved field, needed for padding anyway */
+ XXH64_hash_t reserved64; /*!< Reserved field. Do not read or write to it. */
+}; /* typedef'd to XXH64_state_t */
+
+#ifndef XXH_NO_XXH3
+
+/* Windows SDK under 10.0.22000 is missing stdalign.h so we add a check
+ before allowing the windows compiler to use the C11 form.
+ Reference: https://github.com/Cyan4973/xxHash/issues/955 */
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) \
+ && (defined(_MSC_VER) && (_MSC_VER >= 1000) || !defined(_MSC_VER)) /* >= C11 */
+# include <stdalign.h>
+# define XXH_ALIGN(n) alignas(n)
+#elif defined(__cplusplus) && (__cplusplus >= 201103L) /* >= C++11 */
+/* In C++ alignas() is a keyword */
+# define XXH_ALIGN(n) alignas(n)
+#elif defined(__GNUC__)
+# define XXH_ALIGN(n) __attribute__ ((aligned(n)))
+#elif defined(_MSC_VER)
+# define XXH_ALIGN(n) __declspec(align(n))
+#else
+# define XXH_ALIGN(n) /* disabled */
+#endif
+
+/* Old GCC versions only accept the attribute after the type in structures. */
+#if !(defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) /* C11+ */ \
+ && ! (defined(__cplusplus) && (__cplusplus >= 201103L)) /* >= C++11 */ \
+ && defined(__GNUC__)
+# define XXH_ALIGN_MEMBER(align, type) type XXH_ALIGN(align)
+#else
+# define XXH_ALIGN_MEMBER(align, type) XXH_ALIGN(align) type
+#endif
+
+/*!
+ * @brief The size of the internal XXH3 buffer.
+ *
+ * This is the optimal update size for incremental hashing.
+ *
+ * @see XXH3_64bits_update(), XXH3_128bits_update().
+ */
+#define XXH3_INTERNALBUFFER_SIZE 256
+
+/*!
+ * @internal
+ * @brief Default size of the secret buffer (and @ref XXH3_kSecret).
+ *
+ * This is the size used in @ref XXH3_kSecret and the seeded functions.
+ *
+ * Not to be confused with @ref XXH3_SECRET_SIZE_MIN.
+ */
+#define XXH3_SECRET_DEFAULT_SIZE 192
+
+/*!
+ * @internal
+ * @brief Structure for XXH3 streaming API.
+ *
+ * @note This is only defined when @ref XXH_STATIC_LINKING_ONLY,
+ * @ref XXH_INLINE_ALL, or @ref XXH_IMPLEMENTATION is defined.
+ * Otherwise it is an opaque type.
+ * Never use this definition in combination with a dynamic library.
+ * This allows fields to safely be changed in the future.
+ *
+ * @note ** This structure has a strict alignment requirement of 64 bytes!! **
+ * Do not allocate this with `malloc()` or `new`,
+ * it will not be sufficiently aligned.
+ * Use @ref XXH3_createState() and @ref XXH3_freeState(), or stack allocation.
+ *
+ * Typedef'd to @ref XXH3_state_t.
+ * Never access the members of this struct directly.
+ *
+ * @see XXH3_INITSTATE() for stack initialization.
+ * @see XXH3_createState(), XXH3_freeState().
+ * @see XXH32_state_s, XXH64_state_s
+ */
+struct XXH3_state_s {
+ XXH_ALIGN_MEMBER(64, XXH64_hash_t acc[8]);
+ /*!< The 8 accumulators. See @ref XXH32_state_s::v and @ref XXH64_state_s::v */
+ XXH_ALIGN_MEMBER(64, unsigned char customSecret[XXH3_SECRET_DEFAULT_SIZE]);
+ /*!< Used to store a custom secret generated from a seed. */
+ XXH_ALIGN_MEMBER(64, unsigned char buffer[XXH3_INTERNALBUFFER_SIZE]);
+ /*!< The internal buffer. @see XXH32_state_s::mem32 */
+ XXH32_hash_t bufferedSize;
+ /*!< The amount of memory in @ref buffer, @see XXH32_state_s::memsize */
+ XXH32_hash_t useSeed;
+ /*!< Reserved field. Needed for padding on 64-bit. */
+ size_t nbStripesSoFar;
+ /*!< Number of stripes processed. */
+ XXH64_hash_t totalLen;
+ /*!< Total length hashed. 64-bit even on 32-bit targets. */
+ size_t nbStripesPerBlock;
+ /*!< Number of stripes per block. */
+ size_t secretLimit;
+ /*!< Size of @ref customSecret or @ref extSecret */
+ XXH64_hash_t seed;
+ /*!< Seed for _withSeed variants. Must be zero otherwise, @see XXH3_INITSTATE() */
+ XXH64_hash_t reserved64;
+ /*!< Reserved field. */
+ const unsigned char* extSecret;
+ /*!< Reference to an external secret for the _withSecret variants, NULL
+ * for other variants. */
+ /* note: there may be some padding at the end due to alignment on 64 bytes */
+}; /* typedef'd to XXH3_state_t */
+
+#undef XXH_ALIGN_MEMBER
+
+/*!
+ * @brief Initializes a stack-allocated `XXH3_state_s`.
+ *
+ * When the @ref XXH3_state_t structure is merely emplaced on stack,
+ * it should be initialized with XXH3_INITSTATE() or a memset()
+ * in case its first reset uses XXH3_NNbits_reset_withSeed().
+ * This init can be omitted if the first reset uses default or _withSecret mode.
+ * This operation isn't necessary when the state is created with XXH3_createState().
+ * Note that this doesn't prepare the state for a streaming operation;
+ * it's still necessary to use XXH3_NNbits_reset*() afterwards.
+ */
+#define XXH3_INITSTATE(XXH3_state_ptr) \
+ do { \
+ XXH3_state_t* tmp_xxh3_state_ptr = (XXH3_state_ptr); \
+ tmp_xxh3_state_ptr->seed = 0; \
+ tmp_xxh3_state_ptr->extSecret = NULL; \
+ } while(0)
+
+
+/*!
+ * @brief Calculates the 128-bit hash of @p data using XXH3.
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ * @param seed The 64-bit seed to alter the hash's output predictably.
+ *
+ * @pre
+ * The memory between @p data and @p data + @p len must be valid,
+ * readable, contiguous memory. However, if @p len is `0`, @p data may be
+ * `NULL`. In C++, this also must be *TriviallyCopyable*.
+ *
+ * @return The calculated 128-bit XXH3 value.
+ *
+ * @see @ref single_shot_example "Single Shot Example" for an example.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t XXH128(XXH_NOESCAPE const void* data, size_t len, XXH64_hash_t seed);
+
+
+/* === Experimental API === */
+/* Symbols defined below must be considered tied to a specific library version. */
+
+/*!
+ * @brief Derives a high-entropy secret from any user-defined content, named customSeed.
+ *
+ * @param secretBuffer A writable buffer for derived high-entropy secret data.
+ * @param secretSize Size of secretBuffer, in bytes. Must be >= XXH3_SECRET_SIZE_MIN.
+ * @param customSeed User-defined content.
+ * @param customSeedSize Size of customSeed, in bytes.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * The generated secret can be used in combination with `*_withSecret()` functions.
+ * The `_withSecret()` variants are useful to provide a higher level of protection
+ * than a 64-bit seed, as it becomes much more difficult for an external actor to
+ * guess how to impact the calculation logic.
+ *
+ * The function accepts as input a custom seed of any length and any content,
+ * and derives from it a high-entropy secret of length @p secretSize into an
+ * already allocated buffer @p secretBuffer.
+ *
+ * The generated secret can then be used with any `*_withSecret()` variant.
+ * The functions @ref XXH3_128bits_withSecret(), @ref XXH3_64bits_withSecret(),
+ * @ref XXH3_128bits_reset_withSecret() and @ref XXH3_64bits_reset_withSecret()
+ * are part of this list. They all accept a `secret` parameter
+ * which must be large enough for implementation reasons (>= @ref XXH3_SECRET_SIZE_MIN)
+ * _and_ feature very high entropy (consist of random-looking bytes).
+ * These conditions can be a high bar to meet, so @ref XXH3_generateSecret() can
+ * be employed to ensure proper quality.
+ *
+ * @p customSeed can be anything. It can have any size, even small ones,
+ * and its content can be anything, even "poor entropy" sources such as a bunch
+ * of zeroes. The resulting `secret` will nonetheless provide all required qualities.
+ *
+ * @pre
+ * - @p secretSize must be >= @ref XXH3_SECRET_SIZE_MIN
+ * - When @p customSeedSize > 0, supplying NULL as customSeed is undefined behavior.
+ *
+ * Example code:
+ * @code{.c}
+ *    #include <stdio.h>
+ *    #include <stdlib.h>
+ *    #include <string.h>
+ *    #define XXH_STATIC_LINKING_ONLY // expose unstable API
+ *    #include "xxhash.h"
+ *    // Hashes argv[2] using the entropy from argv[1].
+ *    int main(int argc, char* argv[])
+ *    {
+ *        char secret[XXH3_SECRET_SIZE_MIN];
+ *        if (argc != 3) { return 1; }
+ *        XXH3_generateSecret(secret, sizeof(secret), argv[1], strlen(argv[1]));
+ *        XXH64_hash_t h = XXH3_64bits_withSecret(
+ *             argv[2], strlen(argv[2]),
+ *             secret, sizeof(secret)
+ *        );
+ *        printf("%016llx\n", (unsigned long long) h);
+ *    }
+ * @endcode
+ */
+XXH_PUBLIC_API XXH_errorcode XXH3_generateSecret(XXH_NOESCAPE void* secretBuffer, size_t secretSize, XXH_NOESCAPE const void* customSeed, size_t customSeedSize);
+
+/*!
+ * @brief Generates the same secret as the _withSeed() variants.
+ *
+ * @param secretBuffer A writable buffer of @ref XXH3_SECRET_DEFAULT_SIZE bytes.
+ * @param seed The 64-bit seed to alter the hash result predictably.
+ *
+ * The generated secret can be used in combination with
+ * `*_withSecret()` and `_withSecretandSeed()` variants.
+ *
+ * Example C++ `std::string` hash class:
+ * @code{.cpp}
+ *    #include <string>
+ *    #define XXH_STATIC_LINKING_ONLY // expose unstable API
+ *    #include "xxhash.h"
+ *    // Slow, seeds each time
+ *    class HashSlow {
+ *        XXH64_hash_t seed;
+ *    public:
+ *        HashSlow(XXH64_hash_t s) : seed{s} {}
+ *        size_t operator()(const std::string& x) const {
+ *            return size_t{XXH3_64bits_withSeed(x.c_str(), x.length(), seed)};
+ *        }
+ *    };
+ *    // Fast, caches the seeded secret for future uses.
+ *    class HashFast {
+ *        unsigned char secret[XXH3_SECRET_DEFAULT_SIZE];
+ *    public:
+ *        HashFast(XXH64_hash_t s) {
+ *            XXH3_generateSecret_fromSeed(secret, s);
+ *        }
+ *        size_t operator()(const std::string& x) const {
+ *            return size_t{
+ *                XXH3_64bits_withSecret(x.c_str(), x.length(), secret, sizeof(secret))
+ *            };
+ *        }
+ *    };
+ * @endcode
+ */
+XXH_PUBLIC_API void XXH3_generateSecret_fromSeed(XXH_NOESCAPE void* secretBuffer, XXH64_hash_t seed);
+
+/*!
+ * @brief Maximum size of "short" key in bytes.
+ */
+#define XXH3_MIDSIZE_MAX 240
+
+/*!
+ * @brief Calculates 64/128-bit seeded variant of XXH3 hash of @p data.
+ *
+ * @param data The block of data to be hashed, at least @p len bytes in size.
+ * @param len The length of @p data, in bytes.
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ * @param seed The 64-bit seed to alter the hash result predictably.
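+ *
+ * A combined sketch (hypothetical input; XXH_STATIC_LINKING_ONLY is assumed
+ * for the secret helpers, as in the examples above):
+ * @code{.c}
+ * const char buf[] = "example input";
+ * unsigned char secret[XXH3_SECRET_DEFAULT_SIZE];
+ * XXH3_generateSecret_fromSeed(secret, 1234);  // derive the seeded secret once
+ * // Matches XXH3_64bits_withSeed(buf, sizeof(buf)-1, 1234), but skips
+ * // regenerating the secret for every "large" input:
+ * XXH64_hash_t h = XXH3_64bits_withSecretandSeed(buf, sizeof(buf) - 1,
+ *                                                secret, sizeof(secret), 1234);
+ * @endcode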
+ *
+ * These variants generate hash values using either:
+ * - @p seed for "short" keys (< @ref XXH3_MIDSIZE_MAX = 240 bytes)
+ * - @p secret for "large" keys (>= @ref XXH3_MIDSIZE_MAX).
+ *
+ * This generally benefits speed, compared to `_withSeed()` or `_withSecret()`.
+ * `_withSeed()` has to generate the secret on the fly for "large" keys.
+ * It's fast, but can be perceptible for "not so large" keys (< 1 KB).
+ * `_withSecret()` has to generate the masks on the fly for "small" keys,
+ * which requires more instructions than _withSeed() variants.
+ * Therefore, the _withSecretandSeed() variant combines the best of both worlds.
+ *
+ * When @p secret has been generated by XXH3_generateSecret_fromSeed(),
+ * this variant produces *exactly* the same results as the `_withSeed()` variant,
+ * hence offering only a pure speed benefit on "large" input,
+ * by skipping the need to regenerate the secret for every large input.
+ *
+ * Another usage scenario is to hash the secret to a 64-bit hash value,
+ * for example with XXH3_64bits(), which then becomes the seed,
+ * and then employ both the seed and the secret in _withSecretandSeed().
+ * On top of speed, an added benefit is that each bit in the secret
+ * has a 50% chance to swap each bit in the output, via its impact on the seed.
+ *
+ * This is not guaranteed when using the secret directly in "small data" scenarios,
+ * because only portions of the secret are employed for small data.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH64_hash_t
+XXH3_64bits_withSecretandSeed(XXH_NOESCAPE const void* data, size_t len,
+                              XXH_NOESCAPE const void* secret, size_t secretSize,
+                              XXH64_hash_t seed);
+
+/*!
+ * @brief Calculates 128-bit seeded variant of XXH3 hash of @p input.
+ *
+ * @param input The memory segment to be hashed, at least @p length bytes in size.
+ * @param length The length of @p input, in bytes.
+ * @param secret The secret used to alter hash result predictably.
+ * @param secretSize The length of @p secret, in bytes (must be >= XXH3_SECRET_SIZE_MIN).
+ * @param seed64 The 64-bit seed to alter the hash result predictably.
+ *
+ * @return The calculated 128-bit variant of XXH3 value.
+ *
+ * @see XXH3_64bits_withSecretandSeed(): contract is the same.
+ */
+XXH_PUBLIC_API XXH_PUREF XXH128_hash_t
+XXH3_128bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t length,
+                               XXH_NOESCAPE const void* secret, size_t secretSize,
+                               XXH64_hash_t seed64);
+
+#ifndef XXH_NO_STREAM
+/*!
+ * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash.
+ *
+ * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState().
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ * @param seed64 The 64-bit seed to alter the hash result predictably.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @see XXH3_64bits_withSecretandSeed(). Contract is identical.
+ */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr,
+                                    XXH_NOESCAPE const void* secret, size_t secretSize,
+                                    XXH64_hash_t seed64);
+
+/*!
+ * @brief Resets an @ref XXH3_state_t with secret data to begin a new hash.
+ *
+ * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState().
+ * @param secret The secret data.
+ * @param secretSize The length of @p secret, in bytes.
+ * @param seed64 The 64-bit seed to alter the hash result predictably.
+ *
+ * @return @ref XXH_OK on success.
+ * @return @ref XXH_ERROR on failure.
+ *
+ * @see XXH3_64bits_withSecretandSeed(). Contract is identical.
+ *
+ * Note: there was a bug in an earlier version of this function (<= v0.8.2)
+ * that would make it generate an incorrect hash value
+ * when @p seed == 0 and @p length < XXH3_MIDSIZE_MAX
+ * and @p secret is different from XXH3_generateSecret_fromSeed().
+ * As stated in the contract, the correct hash result must be
+ * the same as XXH3_128bits_withSeed() when @p length <= XXH3_MIDSIZE_MAX.
+ * Results generated by this older version are wrong, hence not comparable.
+ */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_128bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr,
+                                     XXH_NOESCAPE const void* secret, size_t secretSize,
+                                     XXH64_hash_t seed64);
+
+#endif /* !XXH_NO_STREAM */
+
+#endif /* !XXH_NO_XXH3 */
+#endif /* XXH_NO_LONG_LONG */
+#if defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API)
+# define XXH_IMPLEMENTATION
+#endif
+
+#endif /* defined(XXH_STATIC_LINKING_ONLY) && !defined(XXHASH_H_STATIC_13879238742) */
+
+
+/* ======================================================================== */
+/* ======================================================================== */
+/* ======================================================================== */
+
+
+/*-**********************************************************************
+ * xxHash implementation
+ *-**********************************************************************
+ * xxHash's implementation used to be hosted inside xxhash.c.
+ *
+ * However, inlining requires implementation to be visible to the compiler,
+ * hence be included alongside the header.
+ * Previously, implementation was hosted inside xxhash.c,
+ * which was then #included when inlining was activated.
+ * This construction created issues with a few build and install systems,
+ * as it required xxhash.c to be stored in the /include directory.
+ *
+ * xxHash implementation is now directly integrated within xxhash.h.
+ * As a consequence, xxhash.c is no longer needed in /include.
+ *
+ * xxhash.c is still available and is still useful.
+ * In a "normal" setup, when xxhash is not inlined,
+ * xxhash.h only exposes the prototypes and public symbols,
+ * while xxhash.c can be built into an object file xxhash.o
+ * which can then be linked into the final binary.
+ ************************************************************************/
+
+#if ( defined(XXH_INLINE_ALL) || defined(XXH_PRIVATE_API) \
+ || defined(XXH_IMPLEMENTATION) ) && !defined(XXH_IMPLEM_13a8737387)
+# define XXH_IMPLEM_13a8737387
+
+/* *************************************
+* Tuning parameters
+***************************************/
+
+/*!
+ * @defgroup tuning Tuning parameters
+ * @{
+ *
+ * Various macros to control xxHash's behavior.
+ */
+#ifdef XXH_DOXYGEN
+/*!
+ * @brief Define this to disable 64-bit code.
+ *
+ * Useful if only using the @ref XXH32_family and you have a strict C90 compiler.
+ */
+# define XXH_NO_LONG_LONG
+# undef XXH_NO_LONG_LONG /* don't actually */
+/*!
+ * @brief Controls how unaligned memory is accessed.
+ *
+ * By default, access to unaligned memory is controlled by `memcpy()`, which is
+ * safe and portable.
+ *
+ * Unfortunately, on some target/compiler combinations, the generated assembly
+ * is sub-optimal.
+ *
+ * The switch below allows selection of a different access method
+ * in the search for improved performance.
+ *
+ * @par Possible options:
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=0` (default): `memcpy`
+ * @par
+ * Use `memcpy()`. Safe and portable. Note that most modern compilers will
+ * eliminate the function call and treat it as an unaligned access.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=1`: `__attribute__((aligned(1)))`
+ * @par
+ * Depends on compiler extensions and is therefore not portable.
+ * This method is safe _if_ your compiler supports it,
+ * and *generally* as fast or faster than `memcpy`.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=2`: Direct cast
+ * @par
+ * Casts directly and dereferences. This method doesn't depend on the
+ * compiler, but it violates the C standard as it directly dereferences an
+ * unaligned pointer. It can generate buggy code on targets which do not
+ * support unaligned memory accesses, but in some circumstances, it's the
+ * only known way to get the most performance.
+ *
+ * - `XXH_FORCE_MEMORY_ACCESS=3`: Byteshift
+ * @par
+ * Also portable. This can generate the best code on old compilers which don't
+ * inline small `memcpy()` calls, and it might also be faster on big-endian
+ * systems which lack a native byteswap instruction. However, some compilers
+ * will emit literal byteshifts even if the target supports unaligned access.
+ *
+ *
+ * @warning
+ * Methods 1 and 2 rely on implementation-defined behavior. Use these with
+ * care, as what works on one compiler/platform/optimization level may cause
+ * another to read garbage data or even crash.
+ *
+ * See https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html for details.
+ *
+ * Prefer these methods in priority order (0 > 3 > 1 > 2)
+ */
+# define XXH_FORCE_MEMORY_ACCESS 0
+
+/*!
+ * @def XXH_SIZE_OPT
+ * @brief Controls how much xxHash optimizes for size.
+ *
+ * xxHash, when compiled, tends to result in a rather large binary size. This
+ * is mostly due to heavy usage of forced inlining and constant folding of the
+ * @ref XXH3_family to increase performance.
+ *
+ * However, some developers prefer size over speed. This option can
+ * significantly reduce the size of the generated code. When using the `-Os`
+ * or `-Oz` options on GCC or Clang, this is defined to 1 by default,
+ * otherwise it is defined to 0.
+ *
+ * Most of these size optimizations can be controlled manually.
+ *
+ * This is a number from 0-2.
+ * - `XXH_SIZE_OPT` == 0: Default. xxHash makes no size optimizations. Speed
+ * comes first.
+ * - `XXH_SIZE_OPT` == 1: Default for `-Os` and `-Oz`. xxHash is more
+ * conservative and disables hacks that increase code size. It implies the
+ * options @ref XXH_NO_INLINE_HINTS == 1, @ref XXH_FORCE_ALIGN_CHECK == 0,
+ * and @ref XXH3_NEON_LANES == 8 if they are not already defined.
+ * - `XXH_SIZE_OPT` == 2: xxHash tries to make itself as small as possible.
+ * Performance may cry. For example, the single shot functions just use the
+ * streaming API.
+ */
+# define XXH_SIZE_OPT 0
+
+/*!
+ * @def XXH_FORCE_ALIGN_CHECK
+ * @brief If defined to non-zero, adds a special path for aligned inputs (XXH32()
+ * and XXH64() only).
+ *
+ * This is an important performance trick for architectures without decent
+ * unaligned memory access performance.
+ *
+ * It checks for input alignment, and when conditions are met, uses a "fast
+ * path" employing direct 32-bit/64-bit reads, resulting in _dramatically
+ * faster_ read speed.
+ *
+ * The check costs one initial branch per hash, which is generally negligible,
+ * but not zero.
+ *
+ * Moreover, it's not useful to generate an additional code path if memory
+ * access uses the same instruction for both aligned and unaligned
+ * addresses (e.g. x86 and aarch64).
+ *
+ * In these cases, the alignment check can be removed by setting this macro to 0.
+ * Then the code will always use unaligned memory access.
+ * Align check is automatically disabled on x86, x64, ARM64, and some ARM chips
+ * which are platforms known to offer good unaligned memory access performance.
+ *
+ * It is also disabled by default when @ref XXH_SIZE_OPT >= 1.
+ *
+ * This option does not affect XXH3 (only XXH32 and XXH64).
+ */
+# define XXH_FORCE_ALIGN_CHECK 0
+
+/*!
+ * @def XXH_NO_INLINE_HINTS
+ * @brief When non-zero, sets all functions to `static`.
+ *
+ * By default, xxHash tries to force the compiler to inline almost all internal
+ * functions.
+ *
+ * This can usually improve performance due to reduced jumping and improved
+ * constant folding, but significantly increases the size of the binary which
+ * might not be favorable.
+ *
+ * Additionally, sometimes the forced inlining can be detrimental to performance,
+ * depending on the architecture.
+ *
+ * XXH_NO_INLINE_HINTS marks all internal functions as static, giving the
+ * compiler full control over whether to inline or not.
+ *
+ * When not optimizing (-O0), using `-fno-inline` with GCC or Clang, or if
+ * @ref XXH_SIZE_OPT >= 1, this will automatically be defined.
+ */
+# define XXH_NO_INLINE_HINTS 0
+
+/*!
+ * @def XXH3_INLINE_SECRET
+ * @brief Determines whether to inline the XXH3 withSecret code.
+ *
+ * When the secret size is known, the compiler can improve the performance
+ * of XXH3_64bits_withSecret() and XXH3_128bits_withSecret().
+ *
+ * However, if the secret size is not known, it doesn't have any benefit. This
+ * happens when xxHash is compiled into a global symbol. Therefore, if
+ * @ref XXH_INLINE_ALL is *not* defined, this will be defined to 0.
+ *
+ * Additionally, this defaults to 0 on GCC 12+, which has an issue with function pointers
+ * that are *sometimes* force-inlined on -Og, and it is impossible to automatically
+ * detect this optimization level.
+ */
+# define XXH3_INLINE_SECRET 0
+
+/*!
+ * @def XXH32_ENDJMP
+ * @brief Whether to use a jump for `XXH32_finalize`.
+ *
+ * For performance, `XXH32_finalize` uses multiple branches in the finalizer.
+ * This is generally preferable for performance,
+ * but depending on the exact architecture, a jmp may be faster.
+ *
+ * This setting only possibly makes a difference for very small inputs.
+ */
+# define XXH32_ENDJMP 0
+
+/*!
+ * @internal
+ * @brief Redefines old internal names.
+ *
+ * For compatibility with code that uses xxHash's internals before the names
+ * were changed to improve namespacing. There is no other reason to use this.
+ */
+# define XXH_OLD_NAMES
+# undef XXH_OLD_NAMES /* don't actually use, it is ugly. */
+
+/*!
+ * @def XXH_NO_STREAM
+ * @brief Disables the streaming API.
+ *
+ * When xxHash is not inlined and the streaming functions are not used, disabling
+ * the streaming functions can improve code size significantly, especially with
+ * the @ref XXH3_family which tends to make constant folded copies of itself.
+ */
+# define XXH_NO_STREAM
+# undef XXH_NO_STREAM /* don't actually */
+#endif /* XXH_DOXYGEN */
+/*!
+ * @}
+ */
+
+#ifndef XXH_FORCE_MEMORY_ACCESS /* can be defined externally, on command line for example */
+ /* prefer __packed__ structures (method 1) for GCC
+ * < ARMv7 with unaligned access (e.g. Raspbian armhf) still uses byte shifting, so we use memcpy
+ * which for some reason does unaligned loads. */
*/
+# if defined(__GNUC__) && !(defined(__ARM_ARCH) && __ARM_ARCH < 7 && defined(__ARM_FEATURE_UNALIGNED))
+# define XXH_FORCE_MEMORY_ACCESS 1
+# endif
+#endif
+
+#ifndef XXH_SIZE_OPT
+ /* default to 1 for -Os or -Oz */
+# if (defined(__GNUC__) || defined(__clang__)) && defined(__OPTIMIZE_SIZE__)
+# define XXH_SIZE_OPT 1
+# else
+# define XXH_SIZE_OPT 0
+# endif
+#endif
+
+#ifndef XXH_FORCE_ALIGN_CHECK /* can be defined externally */
+ /* don't check on sizeopt, x86, aarch64, or arm when unaligned access is available */
+# if XXH_SIZE_OPT >= 1 || \
+ defined(__i386) || defined(__x86_64__) || defined(__aarch64__) || defined(__ARM_FEATURE_UNALIGNED) \
+ || defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM64) || defined(_M_ARM) /* visual */
+# define XXH_FORCE_ALIGN_CHECK 0
+# else
+# define XXH_FORCE_ALIGN_CHECK 1
+# endif
+#endif
+
+#ifndef XXH_NO_INLINE_HINTS
+# if XXH_SIZE_OPT >= 1 || defined(__NO_INLINE__) /* -O0, -fno-inline */
+# define XXH_NO_INLINE_HINTS 1
+# else
+# define XXH_NO_INLINE_HINTS 0
+# endif
+#endif
+
+#ifndef XXH3_INLINE_SECRET
+# if (defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 12) \
+ || !defined(XXH_INLINE_ALL)
+# define XXH3_INLINE_SECRET 0
+# else
+# define XXH3_INLINE_SECRET 1
+# endif
+#endif
+
+#ifndef XXH32_ENDJMP
+/* generally preferable for performance */
+# define XXH32_ENDJMP 0
+#endif
+
+/*!
+ * @defgroup impl Implementation
+ * @{
+ */
+
+
+/* *************************************
+* Includes & Memory related functions
+***************************************/
+#if defined(XXH_NO_STREAM)
+/* nothing */
+#elif defined(XXH_NO_STDLIB)
+
+/* When requesting to disable any mention of stdlib,
+ * the library loses the ability to invoke malloc / free.
+ * In practice, it means that functions like `XXH*_createState()`
+ * will always fail, and return NULL.
+ * This flag is useful in situations where
+ * xxhash.h is integrated into some kernel, embedded or limited environment
+ * without access to dynamic allocation.
+ */
+
+static XXH_CONSTF void* XXH_malloc(size_t s) { (void)s; return NULL; }
+static void XXH_free(void* p) { (void)p; }
+
+#else
+
+/*
+ * Modify the local functions below should you wish to use
+ * different memory routines for malloc() and free()
+ */
+#include <stdlib.h>
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than malloc().
+ */
+static XXH_MALLOCF void* XXH_malloc(size_t s) { return malloc(s); }
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than free().
+ */
+static void XXH_free(void* p) { free(p); }
+
+#endif /* XXH_NO_STDLIB */
+
+#include <string.h>
+
+/*!
+ * @internal
+ * @brief Modify this function to use a different routine than memcpy().
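+ *
+ * As an illustrative sketch only (not part of the build): a freestanding
+ * target without libc could swap in a plain byte loop here, keeping the
+ * same signature as the definition below:
+ *
+ * @code{.c}
+ * static void* XXH_memcpy(void* dest, const void* src, size_t size)
+ * {
+ *     unsigned char*       d = (unsigned char*)dest;
+ *     const unsigned char* s = (const unsigned char*)src;
+ *     while (size--) *d++ = *s++;  // byte-wise copy, no libc dependency
+ *     return dest;
+ * }
+ * @endcode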
+ */ +static void* XXH_memcpy(void* dest, const void* src, size_t size) +{ + return memcpy(dest,src,size); +} + +#include <limits.h> /* ULLONG_MAX */ + + +/* ************************************* +* Compiler Specific Options +***************************************/ +#ifdef _MSC_VER /* Visual Studio warning fix */ +# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */ +#endif + +#if XXH_NO_INLINE_HINTS /* disable inlining hints */ +# if defined(__GNUC__) || defined(__clang__) +# define XXH_FORCE_INLINE static __attribute__((__unused__)) +# else +# define XXH_FORCE_INLINE static +# endif +# define XXH_NO_INLINE static +/* enable inlining hints */ +#elif defined(__GNUC__) || defined(__clang__) +# define XXH_FORCE_INLINE static __inline__ __attribute__((__always_inline__, __unused__)) +# define XXH_NO_INLINE static __attribute__((__noinline__)) +#elif defined(_MSC_VER) /* Visual Studio */ +# define XXH_FORCE_INLINE static __forceinline +# define XXH_NO_INLINE static __declspec(noinline) +#elif defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) /* C99 */ +# define XXH_FORCE_INLINE static inline +# define XXH_NO_INLINE static +#else +# define XXH_FORCE_INLINE static +# define XXH_NO_INLINE static +#endif + +#if XXH3_INLINE_SECRET +# define XXH3_WITH_SECRET_INLINE XXH_FORCE_INLINE +#else +# define XXH3_WITH_SECRET_INLINE XXH_NO_INLINE +#endif + + +/* ************************************* +* Debug +***************************************/ +/*! + * @ingroup tuning + * @def XXH_DEBUGLEVEL + * @brief Sets the debugging level. + * + * XXH_DEBUGLEVEL is expected to be defined externally, typically via the + * compiler's command line options. The value must be a number. + */ +#ifndef XXH_DEBUGLEVEL +# ifdef DEBUGLEVEL /* backwards compat */ +# define XXH_DEBUGLEVEL DEBUGLEVEL +# else +# define XXH_DEBUGLEVEL 0 +# endif +#endif + +#if (XXH_DEBUGLEVEL>=1) +# include <assert.h> /* note: can still be disabled with NDEBUG */ +# define XXH_ASSERT(c) assert(c) +#else +# if defined(__INTEL_COMPILER) +# define XXH_ASSERT(c) XXH_ASSUME((unsigned char) (c)) +# else +# define XXH_ASSERT(c) XXH_ASSUME(c) +# endif +#endif + +/* note: use after variable declarations */ +#ifndef XXH_STATIC_ASSERT +# if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) /* C11 */ +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { _Static_assert((c),m); } while(0) +# elif defined(__cplusplus) && (__cplusplus >= 201103L) /* C++11 */ +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { static_assert((c),m); } while(0) +# else +# define XXH_STATIC_ASSERT_WITH_MESSAGE(c,m) do { struct xxh_sa { char x[(c) ? 1 : -1]; }; } while(0) +# endif +# define XXH_STATIC_ASSERT(c) XXH_STATIC_ASSERT_WITH_MESSAGE((c),#c) +#endif + +/*! + * @internal + * @def XXH_COMPILER_GUARD(var) + * @brief Used to prevent unwanted optimizations for @p var. + * + * It uses an empty GCC inline assembly statement with a register constraint + * which forces @p var into a general purpose register (eg eax, ebx, ecx + * on x86) and marks it as modified. + * + * This is used in a few places to avoid unwanted autovectorization (e.g. + * XXH32_round()). All vectorization we want is explicit via intrinsics, + * and _usually_ isn't wanted elsewhere. + * + * We also use it to prevent unwanted constant folding for AArch64 in + * XXH3_initCustomSecret_scalar(). 
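+ *
+ * Sketch of the expansion and a typical use (see XXH32_round() further
+ * below for the real use site):
+ *
+ * @code{.c}
+ * acc += input * XXH_PRIME32_2;
+ * acc  = XXH_rotl32(acc, 13);
+ * acc *= XXH_PRIME32_1;
+ * XXH_COMPILER_GUARD(acc);  // on GCC/Clang: __asm__("" : "+r" (acc));
+ *                           // acc is pinned to a scalar register here
+ * @endcode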
+ */ +#if defined(__GNUC__) || defined(__clang__) +# define XXH_COMPILER_GUARD(var) __asm__("" : "+r" (var)) +#else +# define XXH_COMPILER_GUARD(var) ((void)0) +#endif + +/* Specifically for NEON vectors which use the "w" constraint, on + * Clang. */ +#if defined(__clang__) && defined(__ARM_ARCH) && !defined(__wasm__) +# define XXH_COMPILER_GUARD_CLANG_NEON(var) __asm__("" : "+w" (var)) +#else +# define XXH_COMPILER_GUARD_CLANG_NEON(var) ((void)0) +#endif + +/* ************************************* +* Basic Types +***************************************/ +#if !defined (__VMS) \ + && (defined (__cplusplus) \ + || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) ) +# ifdef _AIX +# include <inttypes.h> +# else +# include <stdint.h> +# endif + typedef uint8_t xxh_u8; +#else + typedef unsigned char xxh_u8; +#endif +typedef XXH32_hash_t xxh_u32; + +#ifdef XXH_OLD_NAMES +# warning "XXH_OLD_NAMES is planned to be removed starting v0.9. If the program depends on it, consider moving away from it by employing newer type names directly" +# define BYTE xxh_u8 +# define U8 xxh_u8 +# define U32 xxh_u32 +#endif + +/* *** Memory access *** */ + +/*! + * @internal + * @fn xxh_u32 XXH_read32(const void* ptr) + * @brief Reads an unaligned 32-bit integer from @p ptr in native endianness. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit native endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readLE32(const void* ptr) + * @brief Reads an unaligned 32-bit little endian integer from @p ptr. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit little endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readBE32(const void* ptr) + * @brief Reads an unaligned 32-bit big endian integer from @p ptr. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * + * @param ptr The pointer to read from. + * @return The 32-bit big endian integer from the bytes at @p ptr. + */ + +/*! + * @internal + * @fn xxh_u32 XXH_readLE32_align(const void* ptr, XXH_alignment align) + * @brief Like @ref XXH_readLE32(), but has an option for aligned reads. + * + * Affected by @ref XXH_FORCE_MEMORY_ACCESS. + * Note that when @ref XXH_FORCE_ALIGN_CHECK == 0, the @p align parameter is + * always @ref XXH_alignment::XXH_unaligned. + * + * @param ptr The pointer to read from. + * @param align Whether @p ptr is aligned. + * @pre + * If @p align == @ref XXH_alignment::XXH_aligned, @p ptr must be 4 byte + * aligned. + * @return The 32-bit little endian integer from the bytes at @p ptr. + */ + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE32 and XXH_readBE32. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* + * Force direct memory access. Only works on CPU which support unaligned memory + * access in hardware. + */ +static xxh_u32 XXH_read32(const void* memPtr) { return *(const xxh_u32*) memPtr; } + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __attribute__((aligned(1))) is supported by gcc and clang. 
Originally the + * documentation claimed that it only increased the alignment, but actually it + * can decrease it on gcc, clang, and icc: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=69502, + * https://gcc.godbolt.org/z/xYez1j67Y. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; } __attribute__((__packed__)) unalign; +#endif +static xxh_u32 XXH_read32(const void* ptr) +{ + typedef __attribute__((__aligned__(1))) xxh_u32 xxh_unalign32; + return *((const xxh_unalign32*)ptr); +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html + */ +static xxh_u32 XXH_read32(const void* memPtr) +{ + xxh_u32 val; + XXH_memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + + +/* *** Endianness *** */ + +/*! + * @ingroup tuning + * @def XXH_CPU_LITTLE_ENDIAN + * @brief Whether the target is little endian. + * + * Defined to 1 if the target is little endian, or 0 if it is big endian. + * It can be defined externally, for example on the compiler command line. + * + * If it is not defined, + * a runtime check (which is usually constant folded) is used instead. + * + * @note + * This is not necessarily defined to an integer constant. + * + * @see XXH_isLittleEndian() for the runtime check. + */ +#ifndef XXH_CPU_LITTLE_ENDIAN +/* + * Try to detect endianness automatically, to avoid the nonstandard behavior + * in `XXH_isLittleEndian()` + */ +# if defined(_WIN32) /* Windows is always little endian */ \ + || defined(__LITTLE_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 1 +# elif defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_CPU_LITTLE_ENDIAN 0 +# else +/*! + * @internal + * @brief Runtime check for @ref XXH_CPU_LITTLE_ENDIAN. + * + * Most compilers will constant fold this. + */ +static int XXH_isLittleEndian(void) +{ + /* + * Portable and well-defined behavior. + * Don't use static: it is detrimental to performance. + */ + const union { xxh_u32 u; xxh_u8 c[4]; } one = { 1 }; + return one.c[0]; +} +# define XXH_CPU_LITTLE_ENDIAN XXH_isLittleEndian() +# endif +#endif + + + + +/* **************************************** +* Compiler-specific Functions and Macros +******************************************/ +#define XXH_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) + +#ifdef __has_builtin +# define XXH_HAS_BUILTIN(x) __has_builtin(x) +#else +# define XXH_HAS_BUILTIN(x) 0 +#endif + + + +/* + * C23 and future versions have standard "unreachable()". + * Once it has been implemented reliably we can add it as an + * additional case: + * + * ``` + * #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= XXH_C23_VN) + * # include <stddef.h> + * # ifdef unreachable + * # define XXH_UNREACHABLE() unreachable() + * # endif + * #endif + * ``` + * + * Note C++23 also has std::unreachable() which can be detected + * as follows: + * ``` + * #if defined(__cpp_lib_unreachable) && (__cpp_lib_unreachable >= 202202L) + * # include <utility> + * # define XXH_UNREACHABLE() std::unreachable() + * #endif + * ``` + * NB: `__cpp_lib_unreachable` is defined in the `<version>` header. 
+ * We don't use that as including `<utility>` in `extern "C"` blocks + * doesn't work on GCC12 + */ + +#if XXH_HAS_BUILTIN(__builtin_unreachable) +# define XXH_UNREACHABLE() __builtin_unreachable() + +#elif defined(_MSC_VER) +# define XXH_UNREACHABLE() __assume(0) + +#else +# define XXH_UNREACHABLE() +#endif + +#if XXH_HAS_BUILTIN(__builtin_assume) +# define XXH_ASSUME(c) __builtin_assume(c) +#else +# define XXH_ASSUME(c) if (!(c)) { XXH_UNREACHABLE(); } +#endif + +/*! + * @internal + * @def XXH_rotl32(x,r) + * @brief 32-bit rotate left. + * + * @param x The 32-bit integer to be rotated. + * @param r The number of bits to rotate. + * @pre + * @p r > 0 && @p r < 32 + * @note + * @p x and @p r may be evaluated multiple times. + * @return The rotated result. + */ +#if !defined(NO_CLANG_BUILTIN) && XXH_HAS_BUILTIN(__builtin_rotateleft32) \ + && XXH_HAS_BUILTIN(__builtin_rotateleft64) +# define XXH_rotl32 __builtin_rotateleft32 +# define XXH_rotl64 __builtin_rotateleft64 +/* Note: although _rotl exists for minGW (GCC under windows), performance seems poor */ +#elif defined(_MSC_VER) +# define XXH_rotl32(x,r) _rotl(x,r) +# define XXH_rotl64(x,r) _rotl64(x,r) +#else +# define XXH_rotl32(x,r) (((x) << (r)) | ((x) >> (32 - (r)))) +# define XXH_rotl64(x,r) (((x) << (r)) | ((x) >> (64 - (r)))) +#endif + +/*! + * @internal + * @fn xxh_u32 XXH_swap32(xxh_u32 x) + * @brief A 32-bit byteswap. + * + * @param x The 32-bit integer to byteswap. + * @return @p x, byteswapped. + */ +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap32 _byteswap_ulong +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap32 __builtin_bswap32 +#else +static xxh_u32 XXH_swap32 (xxh_u32 x) +{ + return ((x << 24) & 0xff000000 ) | + ((x << 8) & 0x00ff0000 ) | + ((x >> 8) & 0x0000ff00 ) | + ((x >> 24) & 0x000000ff ); +} +#endif + + +/* *************************** +* Memory reads +*****************************/ + +/*! + * @internal + * @brief Enum to indicate whether a pointer is aligned. + */ +typedef enum { + XXH_aligned, /*!< Aligned */ + XXH_unaligned /*!< Possibly unaligned */ +} XXH_alignment; + +/* + * XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. + * + * This is ideal for older compilers which don't inline memcpy. + */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u32)bytePtr[1] << 8) + | ((xxh_u32)bytePtr[2] << 16) + | ((xxh_u32)bytePtr[3] << 24); +} + +XXH_FORCE_INLINE xxh_u32 XXH_readBE32(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[3] + | ((xxh_u32)bytePtr[2] << 8) + | ((xxh_u32)bytePtr[1] << 16) + | ((xxh_u32)bytePtr[0] << 24); +} + +#else +XXH_FORCE_INLINE xxh_u32 XXH_readLE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read32(ptr) : XXH_swap32(XXH_read32(ptr)); +} + +static xxh_u32 XXH_readBE32(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap32(XXH_read32(ptr)) : XXH_read32(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u32 +XXH_readLE32_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) { + return XXH_readLE32(ptr); + } else { + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u32*)ptr : XXH_swap32(*(const xxh_u32*)ptr); + } +} + + +/* ************************************* +* Misc +***************************************/ +/*! 
@ingroup public */
+XXH_PUBLIC_API unsigned XXH_versionNumber (void) { return XXH_VERSION_NUMBER; }
+
+
+/* *******************************************************************
+* 32-bit hash functions
+*********************************************************************/
+/*!
+ * @}
+ * @defgroup XXH32_impl XXH32 implementation
+ * @ingroup impl
+ *
+ * Details on the XXH32 implementation.
+ * @{
+ */
+ /* #define instead of static const, to be used as initializers */
+#define XXH_PRIME32_1 0x9E3779B1U /*!< 0b10011110001101110111100110110001 */
+#define XXH_PRIME32_2 0x85EBCA77U /*!< 0b10000101111010111100101001110111 */
+#define XXH_PRIME32_3 0xC2B2AE3DU /*!< 0b11000010101100101010111000111101 */
+#define XXH_PRIME32_4 0x27D4EB2FU /*!< 0b00100111110101001110101100101111 */
+#define XXH_PRIME32_5 0x165667B1U /*!< 0b00010110010101100110011110110001 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME32_1 XXH_PRIME32_1
+# define PRIME32_2 XXH_PRIME32_2
+# define PRIME32_3 XXH_PRIME32_3
+# define PRIME32_4 XXH_PRIME32_4
+# define PRIME32_5 XXH_PRIME32_5
+#endif
+
+/*!
+ * @internal
+ * @brief Normal stripe processing routine.
+ *
+ * This shuffles the bits so that any bit from @p input impacts several bits in
+ * @p acc.
+ *
+ * @param acc The accumulator lane.
+ * @param input The stripe of input to mix.
+ * @return The mixed accumulator lane.
+ */
+static xxh_u32 XXH32_round(xxh_u32 acc, xxh_u32 input)
+{
+ acc += input * XXH_PRIME32_2;
+ acc = XXH_rotl32(acc, 13);
+ acc *= XXH_PRIME32_1;
+#if (defined(__SSE4_1__) || defined(__aarch64__) || defined(__wasm_simd128__)) && !defined(XXH_ENABLE_AUTOVECTORIZE)
+ /*
+ * UGLY HACK:
+ * A compiler fence is used to prevent GCC and Clang from
+ * autovectorizing the XXH32 loop (pragmas and attributes don't work for some
+ * reason) without globally disabling SSE4.1.
+ *
+ * The reason we want to avoid vectorization is because despite working on
+ * 4 integers at a time, there are multiple factors slowing XXH32 down on
+ * SSE4:
+ * - There's a ridiculous amount of lag from pmulld (10 cycles of latency on
+ * newer chips!) making it slightly slower to multiply four integers at
+ * once compared to four integers independently. Even when pmulld was
+ * at its fastest (Sandy/Ivy Bridge), it was still not worth it to go into SSE
+ * just to multiply unless doing a long operation.
+ *
+ * - Four instructions are required to rotate,
+ * movdqa tmp, v // not required with VEX encoding
+ * pslld tmp, 13 // tmp <<= 13
+ * psrld v, 19 // x >>= 19
+ * por v, tmp // x |= tmp
+ * compared to one for scalar:
+ * roll v, 13 // reliably fast across the board
+ * shldl v, v, 13 // Sandy Bridge and later prefer this for some reason
+ *
+ * - Instruction level parallelism is actually more beneficial here because
+ * the SIMD actually serializes this operation: While v1 is rotating, v2
+ * can load data, while v3 can multiply. SSE forces them to operate
+ * together.
+ *
+ * This is also enabled on AArch64, as Clang is *very aggressive* in vectorizing
+ * the loop. NEON is only faster on the A53, and with the newer cores, it is less
+ * than half the speed.
+ *
+ * Additionally, this is used on WASM SIMD128 because it JITs to the same
+ * SIMD instructions and has the same issue.
+ */
+ XXH_COMPILER_GUARD(acc);
+#endif
+ return acc;
+}
+
+/*!
+ * @internal
+ * @brief Mixes all bits to finalize the hash.
+ *
+ * The final mix ensures that all input bits have a chance to impact any bit in
+ * the output digest, resulting in an unbiased distribution.
+ *
+ * @param hash The hash to avalanche.
+ * @return The avalanched hash. + */ +static xxh_u32 XXH32_avalanche(xxh_u32 hash) +{ + hash ^= hash >> 15; + hash *= XXH_PRIME32_2; + hash ^= hash >> 13; + hash *= XXH_PRIME32_3; + hash ^= hash >> 16; + return hash; +} + +#define XXH_get32bits(p) XXH_readLE32_align(p, align) + +/*! + * @internal + * @brief Processes the last 0-15 bytes of @p ptr. + * + * There may be up to 15 bytes remaining to consume from the input. + * This final stage will digest them to ensure that all input bytes are present + * in the final mix. + * + * @param hash The hash to finalize. + * @param ptr The pointer to the remaining input. + * @param len The remaining length, modulo 16. + * @param align Whether @p ptr is aligned. + * @return The finalized hash. + * @see XXH64_finalize(). + */ +static XXH_PUREF xxh_u32 +XXH32_finalize(xxh_u32 hash, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ +#define XXH_PROCESS1 do { \ + hash += (*ptr++) * XXH_PRIME32_5; \ + hash = XXH_rotl32(hash, 11) * XXH_PRIME32_1; \ +} while (0) + +#define XXH_PROCESS4 do { \ + hash += XXH_get32bits(ptr) * XXH_PRIME32_3; \ + ptr += 4; \ + hash = XXH_rotl32(hash, 17) * XXH_PRIME32_4; \ +} while (0) + + if (ptr==NULL) XXH_ASSERT(len == 0); + + /* Compact rerolled version; generally faster */ + if (!XXH32_ENDJMP) { + len &= 15; + while (len >= 4) { + XXH_PROCESS4; + len -= 4; + } + while (len > 0) { + XXH_PROCESS1; + --len; + } + return XXH32_avalanche(hash); + } else { + switch(len&15) /* or switch(bEnd - p) */ { + case 12: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 8: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 4: XXH_PROCESS4; + return XXH32_avalanche(hash); + + case 13: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 9: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 5: XXH_PROCESS4; + XXH_PROCESS1; + return XXH32_avalanche(hash); + + case 14: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 10: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 6: XXH_PROCESS4; + XXH_PROCESS1; + XXH_PROCESS1; + return XXH32_avalanche(hash); + + case 15: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 11: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 7: XXH_PROCESS4; + XXH_FALLTHROUGH; /* fallthrough */ + case 3: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 2: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 1: XXH_PROCESS1; + XXH_FALLTHROUGH; /* fallthrough */ + case 0: return XXH32_avalanche(hash); + } + XXH_ASSERT(0); + return hash; /* reaching this point is deemed impossible */ + } +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1 XXH_PROCESS1 +# define PROCESS4 XXH_PROCESS4 +#else +# undef XXH_PROCESS1 +# undef XXH_PROCESS4 +#endif + +/*! + * @internal + * @brief The implementation for @ref XXH32(). + * + * @param input , len , seed Directly passed from @ref XXH32(). + * @param align Whether @p input is aligned. + * @return The calculated hash. 
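+ * @note
+ * Callers normally reach this through the public wrapper. A one-shot use
+ * looks like this (illustrative):
+ * @code{.c}
+ * const char msg[] = "xxhash";
+ * XXH32_hash_t h = XXH32(msg, sizeof(msg) - 1, 0);  // data, length, seed
+ * @endcode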
+ */ +XXH_FORCE_INLINE XXH_PUREF xxh_u32 +XXH32_endian_align(const xxh_u8* input, size_t len, xxh_u32 seed, XXH_alignment align) +{ + xxh_u32 h32; + + if (input==NULL) XXH_ASSERT(len == 0); + + if (len>=16) { + const xxh_u8* const bEnd = input + len; + const xxh_u8* const limit = bEnd - 15; + xxh_u32 v1 = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + xxh_u32 v2 = seed + XXH_PRIME32_2; + xxh_u32 v3 = seed + 0; + xxh_u32 v4 = seed - XXH_PRIME32_1; + + do { + v1 = XXH32_round(v1, XXH_get32bits(input)); input += 4; + v2 = XXH32_round(v2, XXH_get32bits(input)); input += 4; + v3 = XXH32_round(v3, XXH_get32bits(input)); input += 4; + v4 = XXH32_round(v4, XXH_get32bits(input)); input += 4; + } while (input < limit); + + h32 = XXH_rotl32(v1, 1) + XXH_rotl32(v2, 7) + + XXH_rotl32(v3, 12) + XXH_rotl32(v4, 18); + } else { + h32 = seed + XXH_PRIME32_5; + } + + h32 += (xxh_u32)len; + + return XXH32_finalize(h32, input, len&15, align); +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32 (const void* input, size_t len, XXH32_hash_t seed) +{ +#if !defined(XXH_NO_STREAM) && XXH_SIZE_OPT >= 2 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH32_state_t state; + XXH32_reset(&state, seed); + XXH32_update(&state, (const xxh_u8*)input, len); + return XXH32_digest(&state); +#else + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 3) == 0) { /* Input is 4-bytes aligned, leverage the speed benefit */ + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH32_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); +#endif +} + + + +/******* Hash streaming *******/ +#ifndef XXH_NO_STREAM +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_state_t* XXH32_createState(void) +{ + return (XXH32_state_t*)XXH_malloc(sizeof(XXH32_state_t)); +} +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode XXH32_freeState(XXH32_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API void XXH32_copyState(XXH32_state_t* dstState, const XXH32_state_t* srcState) +{ + XXH_memcpy(dstState, srcState, sizeof(*dstState)); +} + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode XXH32_reset(XXH32_state_t* statePtr, XXH32_hash_t seed) +{ + XXH_ASSERT(statePtr != NULL); + memset(statePtr, 0, sizeof(*statePtr)); + statePtr->v[0] = seed + XXH_PRIME32_1 + XXH_PRIME32_2; + statePtr->v[1] = seed + XXH_PRIME32_2; + statePtr->v[2] = seed + 0; + statePtr->v[3] = seed - XXH_PRIME32_1; + return XXH_OK; +} + + +/*! 
@ingroup XXH32_family */ +XXH_PUBLIC_API XXH_errorcode +XXH32_update(XXH32_state_t* state, const void* input, size_t len) +{ + if (input==NULL) { + XXH_ASSERT(len == 0); + return XXH_OK; + } + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len_32 += (XXH32_hash_t)len; + state->large_len |= (XXH32_hash_t)((len>=16) | (state->total_len_32>=16)); + + if (state->memsize + len < 16) { /* fill in tmp buffer */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, len); + state->memsize += (XXH32_hash_t)len; + return XXH_OK; + } + + if (state->memsize) { /* some data left from previous update */ + XXH_memcpy((xxh_u8*)(state->mem32) + state->memsize, input, 16-state->memsize); + { const xxh_u32* p32 = state->mem32; + state->v[0] = XXH32_round(state->v[0], XXH_readLE32(p32)); p32++; + state->v[1] = XXH32_round(state->v[1], XXH_readLE32(p32)); p32++; + state->v[2] = XXH32_round(state->v[2], XXH_readLE32(p32)); p32++; + state->v[3] = XXH32_round(state->v[3], XXH_readLE32(p32)); + } + p += 16-state->memsize; + state->memsize = 0; + } + + if (p <= bEnd-16) { + const xxh_u8* const limit = bEnd - 16; + + do { + state->v[0] = XXH32_round(state->v[0], XXH_readLE32(p)); p+=4; + state->v[1] = XXH32_round(state->v[1], XXH_readLE32(p)); p+=4; + state->v[2] = XXH32_round(state->v[2], XXH_readLE32(p)); p+=4; + state->v[3] = XXH32_round(state->v[3], XXH_readLE32(p)); p+=4; + } while (p<=limit); + + } + + if (p < bEnd) { + XXH_memcpy(state->mem32, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32_digest(const XXH32_state_t* state) +{ + xxh_u32 h32; + + if (state->large_len) { + h32 = XXH_rotl32(state->v[0], 1) + + XXH_rotl32(state->v[1], 7) + + XXH_rotl32(state->v[2], 12) + + XXH_rotl32(state->v[3], 18); + } else { + h32 = state->v[2] /* == seed */ + XXH_PRIME32_5; + } + + h32 += state->total_len_32; + + return XXH32_finalize(h32, (const xxh_u8*)state->mem32, state->memsize, XXH_aligned); +} +#endif /* !XXH_NO_STREAM */ + +/******* Canonical representation *******/ + +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API void XXH32_canonicalFromHash(XXH32_canonical_t* dst, XXH32_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH32_canonical_t) == sizeof(XXH32_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap32(hash); + XXH_memcpy(dst, &hash, sizeof(*dst)); +} +/*! @ingroup XXH32_family */ +XXH_PUBLIC_API XXH32_hash_t XXH32_hashFromCanonical(const XXH32_canonical_t* src) +{ + return XXH_readBE32(src); +} + + +#ifndef XXH_NO_LONG_LONG + +/* ******************************************************************* +* 64-bit hash functions +*********************************************************************/ +/*! + * @} + * @ingroup impl + * @{ + */ +/******* Memory access *******/ + +typedef XXH64_hash_t xxh_u64; + +#ifdef XXH_OLD_NAMES +# define U64 xxh_u64 +#endif + +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) +/* + * Manual byteshift. Best for old compilers which don't inline memcpy. + * We actually directly use XXH_readLE64 and XXH_readBE64. + */ +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==2)) + +/* Force direct memory access. 
Only works on CPU which support unaligned memory access in hardware */ +static xxh_u64 XXH_read64(const void* memPtr) +{ + return *(const xxh_u64*) memPtr; +} + +#elif (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==1)) + +/* + * __attribute__((aligned(1))) is supported by gcc and clang. Originally the + * documentation claimed that it only increased the alignment, but actually it + * can decrease it on gcc, clang, and icc: + * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=69502, + * https://gcc.godbolt.org/z/xYez1j67Y. + */ +#ifdef XXH_OLD_NAMES +typedef union { xxh_u32 u32; xxh_u64 u64; } __attribute__((__packed__)) unalign64; +#endif +static xxh_u64 XXH_read64(const void* ptr) +{ + typedef __attribute__((__aligned__(1))) xxh_u64 xxh_unalign64; + return *((const xxh_unalign64*)ptr); +} + +#else + +/* + * Portable and safe solution. Generally efficient. + * see: https://fastcompression.blogspot.com/2015/08/accessing-unaligned-memory.html + */ +static xxh_u64 XXH_read64(const void* memPtr) +{ + xxh_u64 val; + XXH_memcpy(&val, memPtr, sizeof(val)); + return val; +} + +#endif /* XXH_FORCE_DIRECT_MEMORY_ACCESS */ + +#if defined(_MSC_VER) /* Visual Studio */ +# define XXH_swap64 _byteswap_uint64 +#elif XXH_GCC_VERSION >= 403 +# define XXH_swap64 __builtin_bswap64 +#else +static xxh_u64 XXH_swap64(xxh_u64 x) +{ + return ((x << 56) & 0xff00000000000000ULL) | + ((x << 40) & 0x00ff000000000000ULL) | + ((x << 24) & 0x0000ff0000000000ULL) | + ((x << 8) & 0x000000ff00000000ULL) | + ((x >> 8) & 0x00000000ff000000ULL) | + ((x >> 24) & 0x0000000000ff0000ULL) | + ((x >> 40) & 0x000000000000ff00ULL) | + ((x >> 56) & 0x00000000000000ffULL); +} +#endif + + +/* XXH_FORCE_MEMORY_ACCESS==3 is an endian-independent byteshift load. */ +#if (defined(XXH_FORCE_MEMORY_ACCESS) && (XXH_FORCE_MEMORY_ACCESS==3)) + +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[0] + | ((xxh_u64)bytePtr[1] << 8) + | ((xxh_u64)bytePtr[2] << 16) + | ((xxh_u64)bytePtr[3] << 24) + | ((xxh_u64)bytePtr[4] << 32) + | ((xxh_u64)bytePtr[5] << 40) + | ((xxh_u64)bytePtr[6] << 48) + | ((xxh_u64)bytePtr[7] << 56); +} + +XXH_FORCE_INLINE xxh_u64 XXH_readBE64(const void* memPtr) +{ + const xxh_u8* bytePtr = (const xxh_u8 *)memPtr; + return bytePtr[7] + | ((xxh_u64)bytePtr[6] << 8) + | ((xxh_u64)bytePtr[5] << 16) + | ((xxh_u64)bytePtr[4] << 24) + | ((xxh_u64)bytePtr[3] << 32) + | ((xxh_u64)bytePtr[2] << 40) + | ((xxh_u64)bytePtr[1] << 48) + | ((xxh_u64)bytePtr[0] << 56); +} + +#else +XXH_FORCE_INLINE xxh_u64 XXH_readLE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_read64(ptr) : XXH_swap64(XXH_read64(ptr)); +} + +static xxh_u64 XXH_readBE64(const void* ptr) +{ + return XXH_CPU_LITTLE_ENDIAN ? XXH_swap64(XXH_read64(ptr)) : XXH_read64(ptr); +} +#endif + +XXH_FORCE_INLINE xxh_u64 +XXH_readLE64_align(const void* ptr, XXH_alignment align) +{ + if (align==XXH_unaligned) + return XXH_readLE64(ptr); + else + return XXH_CPU_LITTLE_ENDIAN ? *(const xxh_u64*)ptr : XXH_swap64(*(const xxh_u64*)ptr); +} + + +/******* xxh64 *******/ +/*! + * @} + * @defgroup XXH64_impl XXH64 implementation + * @ingroup impl + * + * Details on the XXH64 implementation. 
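+ * The structure mirrors XXH32: four 64-bit accumulator lanes consume
+ * 32-byte stripes, the lanes are merged, and the result is avalanched.
+ * One-shot use of the public entry point (illustrative):
+ *
+ * @code{.c}
+ * XXH64_hash_t h = XXH64("xxhash", 6, 0);  // data, length, seed
+ * @endcode
+ *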
+ * @{
+ */
+/* #define rather than static const, to be used as initializers */
+#define XXH_PRIME64_1 0x9E3779B185EBCA87ULL /*!< 0b1001111000110111011110011011000110000101111010111100101010000111 */
+#define XXH_PRIME64_2 0xC2B2AE3D27D4EB4FULL /*!< 0b1100001010110010101011100011110100100111110101001110101101001111 */
+#define XXH_PRIME64_3 0x165667B19E3779F9ULL /*!< 0b0001011001010110011001111011000110011110001101110111100111111001 */
+#define XXH_PRIME64_4 0x85EBCA77C2B2AE63ULL /*!< 0b1000010111101011110010100111011111000010101100101010111001100011 */
+#define XXH_PRIME64_5 0x27D4EB2F165667C5ULL /*!< 0b0010011111010100111010110010111100010110010101100110011111000101 */
+
+#ifdef XXH_OLD_NAMES
+# define PRIME64_1 XXH_PRIME64_1
+# define PRIME64_2 XXH_PRIME64_2
+# define PRIME64_3 XXH_PRIME64_3
+# define PRIME64_4 XXH_PRIME64_4
+# define PRIME64_5 XXH_PRIME64_5
+#endif
+
+/*! @copydoc XXH32_round */
+static xxh_u64 XXH64_round(xxh_u64 acc, xxh_u64 input)
+{
+ acc += input * XXH_PRIME64_2;
+ acc = XXH_rotl64(acc, 31);
+ acc *= XXH_PRIME64_1;
+#if (defined(__AVX512F__)) && !defined(XXH_ENABLE_AUTOVECTORIZE)
+ /*
+ * DISABLE AUTOVECTORIZATION:
+ * A compiler fence is used to prevent GCC and Clang from
+ * autovectorizing the XXH64 loop (pragmas and attributes don't work for some
+ * reason) without globally disabling AVX512.
+ *
+ * Autovectorization of XXH64 tends to be detrimental,
+ * though the exact outcome may change depending on the exact cpu and compiler version.
+ * For information, it has been reported as detrimental for Skylake-X,
+ * but possibly beneficial for Zen4.
+ *
+ * The default is to disable auto-vectorization,
+ * but you can enable it instead using the `XXH_ENABLE_AUTOVECTORIZE` build variable.
+ */
+ XXH_COMPILER_GUARD(acc);
+#endif
+ return acc;
+}
+
+static xxh_u64 XXH64_mergeRound(xxh_u64 acc, xxh_u64 val)
+{
+ val = XXH64_round(0, val);
+ acc ^= val;
+ acc = acc * XXH_PRIME64_1 + XXH_PRIME64_4;
+ return acc;
+}
+
+/*! @copydoc XXH32_avalanche */
+static xxh_u64 XXH64_avalanche(xxh_u64 hash)
+{
+ hash ^= hash >> 33;
+ hash *= XXH_PRIME64_2;
+ hash ^= hash >> 29;
+ hash *= XXH_PRIME64_3;
+ hash ^= hash >> 32;
+ return hash;
+}
+
+
+#define XXH_get64bits(p) XXH_readLE64_align(p, align)
+
+/*!
+ * @internal
+ * @brief Processes the last 0-31 bytes of @p ptr.
+ *
+ * There may be up to 31 bytes remaining to consume from the input.
+ * This final stage will digest them to ensure that all input bytes are present
+ * in the final mix.
+ *
+ * @param hash The hash to finalize.
+ * @param ptr The pointer to the remaining input.
+ * @param len The remaining length, modulo 32.
+ * @param align Whether @p ptr is aligned.
+ * @return The finalized hash.
+ * @see XXH32_finalize().
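+ *
+ * Worked example (sketch): for len == 13, the code below consumes one
+ * 8-byte word, then one 4-byte word, then one trailing byte
+ * (8 + 4 + 1 = 13), and finishes with XXH64_avalanche().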
+ */ +static XXH_PUREF xxh_u64 +XXH64_finalize(xxh_u64 hash, const xxh_u8* ptr, size_t len, XXH_alignment align) +{ + if (ptr==NULL) XXH_ASSERT(len == 0); + len &= 31; + while (len >= 8) { + xxh_u64 const k1 = XXH64_round(0, XXH_get64bits(ptr)); + ptr += 8; + hash ^= k1; + hash = XXH_rotl64(hash,27) * XXH_PRIME64_1 + XXH_PRIME64_4; + len -= 8; + } + if (len >= 4) { + hash ^= (xxh_u64)(XXH_get32bits(ptr)) * XXH_PRIME64_1; + ptr += 4; + hash = XXH_rotl64(hash, 23) * XXH_PRIME64_2 + XXH_PRIME64_3; + len -= 4; + } + while (len > 0) { + hash ^= (*ptr++) * XXH_PRIME64_5; + hash = XXH_rotl64(hash, 11) * XXH_PRIME64_1; + --len; + } + return XXH64_avalanche(hash); +} + +#ifdef XXH_OLD_NAMES +# define PROCESS1_64 XXH_PROCESS1_64 +# define PROCESS4_64 XXH_PROCESS4_64 +# define PROCESS8_64 XXH_PROCESS8_64 +#else +# undef XXH_PROCESS1_64 +# undef XXH_PROCESS4_64 +# undef XXH_PROCESS8_64 +#endif + +/*! + * @internal + * @brief The implementation for @ref XXH64(). + * + * @param input , len , seed Directly passed from @ref XXH64(). + * @param align Whether @p input is aligned. + * @return The calculated hash. + */ +XXH_FORCE_INLINE XXH_PUREF xxh_u64 +XXH64_endian_align(const xxh_u8* input, size_t len, xxh_u64 seed, XXH_alignment align) +{ + xxh_u64 h64; + if (input==NULL) XXH_ASSERT(len == 0); + + if (len>=32) { + const xxh_u8* const bEnd = input + len; + const xxh_u8* const limit = bEnd - 31; + xxh_u64 v1 = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + xxh_u64 v2 = seed + XXH_PRIME64_2; + xxh_u64 v3 = seed + 0; + xxh_u64 v4 = seed - XXH_PRIME64_1; + + do { + v1 = XXH64_round(v1, XXH_get64bits(input)); input+=8; + v2 = XXH64_round(v2, XXH_get64bits(input)); input+=8; + v3 = XXH64_round(v3, XXH_get64bits(input)); input+=8; + v4 = XXH64_round(v4, XXH_get64bits(input)); input+=8; + } while (input<limit); + + h64 = XXH_rotl64(v1, 1) + XXH_rotl64(v2, 7) + XXH_rotl64(v3, 12) + XXH_rotl64(v4, 18); + h64 = XXH64_mergeRound(h64, v1); + h64 = XXH64_mergeRound(h64, v2); + h64 = XXH64_mergeRound(h64, v3); + h64 = XXH64_mergeRound(h64, v4); + + } else { + h64 = seed + XXH_PRIME64_5; + } + + h64 += (xxh_u64) len; + + return XXH64_finalize(h64, input, len, align); +} + + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH64_hash_t XXH64 (XXH_NOESCAPE const void* input, size_t len, XXH64_hash_t seed) +{ +#if !defined(XXH_NO_STREAM) && XXH_SIZE_OPT >= 2 + /* Simple version, good for code maintenance, but unfortunately slow for small inputs */ + XXH64_state_t state; + XXH64_reset(&state, seed); + XXH64_update(&state, (const xxh_u8*)input, len); + return XXH64_digest(&state); +#else + if (XXH_FORCE_ALIGN_CHECK) { + if ((((size_t)input) & 7)==0) { /* Input is aligned, let's leverage the speed advantage */ + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_aligned); + } } + + return XXH64_endian_align((const xxh_u8*)input, len, seed, XXH_unaligned); + +#endif +} + +/******* Hash Streaming *******/ +#ifndef XXH_NO_STREAM +/*! @ingroup XXH64_family*/ +XXH_PUBLIC_API XXH64_state_t* XXH64_createState(void) +{ + return (XXH64_state_t*)XXH_malloc(sizeof(XXH64_state_t)); +} +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode XXH64_freeState(XXH64_state_t* statePtr) +{ + XXH_free(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API void XXH64_copyState(XXH_NOESCAPE XXH64_state_t* dstState, const XXH64_state_t* srcState) +{ + XXH_memcpy(dstState, srcState, sizeof(*dstState)); +} + +/*! 
@ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode XXH64_reset(XXH_NOESCAPE XXH64_state_t* statePtr, XXH64_hash_t seed) +{ + XXH_ASSERT(statePtr != NULL); + memset(statePtr, 0, sizeof(*statePtr)); + statePtr->v[0] = seed + XXH_PRIME64_1 + XXH_PRIME64_2; + statePtr->v[1] = seed + XXH_PRIME64_2; + statePtr->v[2] = seed + 0; + statePtr->v[3] = seed - XXH_PRIME64_1; + return XXH_OK; +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH_errorcode +XXH64_update (XXH_NOESCAPE XXH64_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + if (input==NULL) { + XXH_ASSERT(len == 0); + return XXH_OK; + } + + { const xxh_u8* p = (const xxh_u8*)input; + const xxh_u8* const bEnd = p + len; + + state->total_len += len; + + if (state->memsize + len < 32) { /* fill in tmp buffer */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, len); + state->memsize += (xxh_u32)len; + return XXH_OK; + } + + if (state->memsize) { /* tmp buffer is full */ + XXH_memcpy(((xxh_u8*)state->mem64) + state->memsize, input, 32-state->memsize); + state->v[0] = XXH64_round(state->v[0], XXH_readLE64(state->mem64+0)); + state->v[1] = XXH64_round(state->v[1], XXH_readLE64(state->mem64+1)); + state->v[2] = XXH64_round(state->v[2], XXH_readLE64(state->mem64+2)); + state->v[3] = XXH64_round(state->v[3], XXH_readLE64(state->mem64+3)); + p += 32 - state->memsize; + state->memsize = 0; + } + + if (p+32 <= bEnd) { + const xxh_u8* const limit = bEnd - 32; + + do { + state->v[0] = XXH64_round(state->v[0], XXH_readLE64(p)); p+=8; + state->v[1] = XXH64_round(state->v[1], XXH_readLE64(p)); p+=8; + state->v[2] = XXH64_round(state->v[2], XXH_readLE64(p)); p+=8; + state->v[3] = XXH64_round(state->v[3], XXH_readLE64(p)); p+=8; + } while (p<=limit); + + } + + if (p < bEnd) { + XXH_memcpy(state->mem64, p, (size_t)(bEnd-p)); + state->memsize = (unsigned)(bEnd-p); + } + } + + return XXH_OK; +} + + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH64_hash_t XXH64_digest(XXH_NOESCAPE const XXH64_state_t* state) +{ + xxh_u64 h64; + + if (state->total_len >= 32) { + h64 = XXH_rotl64(state->v[0], 1) + XXH_rotl64(state->v[1], 7) + XXH_rotl64(state->v[2], 12) + XXH_rotl64(state->v[3], 18); + h64 = XXH64_mergeRound(h64, state->v[0]); + h64 = XXH64_mergeRound(h64, state->v[1]); + h64 = XXH64_mergeRound(h64, state->v[2]); + h64 = XXH64_mergeRound(h64, state->v[3]); + } else { + h64 = state->v[2] /*seed*/ + XXH_PRIME64_5; + } + + h64 += (xxh_u64) state->total_len; + + return XXH64_finalize(h64, (const xxh_u8*)state->mem64, (size_t)state->total_len, XXH_aligned); +} +#endif /* !XXH_NO_STREAM */ + +/******* Canonical representation *******/ + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API void XXH64_canonicalFromHash(XXH_NOESCAPE XXH64_canonical_t* dst, XXH64_hash_t hash) +{ + XXH_STATIC_ASSERT(sizeof(XXH64_canonical_t) == sizeof(XXH64_hash_t)); + if (XXH_CPU_LITTLE_ENDIAN) hash = XXH_swap64(hash); + XXH_memcpy(dst, &hash, sizeof(*dst)); +} + +/*! @ingroup XXH64_family */ +XXH_PUBLIC_API XXH64_hash_t XXH64_hashFromCanonical(XXH_NOESCAPE const XXH64_canonical_t* src) +{ + return XXH_readBE64(src); +} + +#ifndef XXH_NO_XXH3 + +/* ********************************************************************* +* XXH3 +* New generation hash designed for speed on small keys and vectorization +************************************************************************ */ +/*! 
+ * @} + * @defgroup XXH3_impl XXH3 implementation + * @ingroup impl + * @{ + */ + +/* === Compiler specifics === */ + +#if ((defined(sun) || defined(__sun)) && __cplusplus) /* Solaris includes __STDC_VERSION__ with C++. Tested with GCC 5.5 */ +# define XXH_RESTRICT /* disable */ +#elif defined (__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* >= C99 */ +# define XXH_RESTRICT restrict +#elif (defined (__GNUC__) && ((__GNUC__ > 3) || (__GNUC__ == 3 && __GNUC_MINOR__ >= 1))) \ + || (defined (__clang__)) \ + || (defined (_MSC_VER) && (_MSC_VER >= 1400)) \ + || (defined (__INTEL_COMPILER) && (__INTEL_COMPILER >= 1300)) +/* + * There are a LOT more compilers that recognize __restrict but this + * covers the major ones. + */ +# define XXH_RESTRICT __restrict +#else +# define XXH_RESTRICT /* disable */ +#endif + +#if (defined(__GNUC__) && (__GNUC__ >= 3)) \ + || (defined(__INTEL_COMPILER) && (__INTEL_COMPILER >= 800)) \ + || defined(__clang__) +# define XXH_likely(x) __builtin_expect(x, 1) +# define XXH_unlikely(x) __builtin_expect(x, 0) +#else +# define XXH_likely(x) (x) +# define XXH_unlikely(x) (x) +#endif + +#ifndef XXH_HAS_INCLUDE +# ifdef __has_include +/* + * Not defined as XXH_HAS_INCLUDE(x) (function-like) because + * this causes segfaults in Apple Clang 4.2 (on Mac OS X 10.7 Lion) + */ +# define XXH_HAS_INCLUDE __has_include +# else +# define XXH_HAS_INCLUDE(x) 0 +# endif +#endif + +#if defined(__GNUC__) || defined(__clang__) +# if defined(__ARM_FEATURE_SVE) +# include <arm_sve.h> +# endif +# if defined(__ARM_NEON__) || defined(__ARM_NEON) \ + || (defined(_M_ARM) && _M_ARM >= 7) \ + || defined(_M_ARM64) || defined(_M_ARM64EC) \ + || (defined(__wasm_simd128__) && XXH_HAS_INCLUDE(<arm_neon.h>)) /* WASM SIMD128 via SIMDe */ +# define inline __inline__ /* circumvent a clang bug */ +# include <arm_neon.h> +# undef inline +# elif defined(__AVX2__) +# include <immintrin.h> +# elif defined(__SSE2__) +# include <emmintrin.h> +# endif +#endif + +#if defined(_MSC_VER) +# include <intrin.h> +#endif + +/* + * One goal of XXH3 is to make it fast on both 32-bit and 64-bit, while + * remaining a true 64-bit/128-bit hash function. + * + * This is done by prioritizing a subset of 64-bit operations that can be + * emulated without too many steps on the average 32-bit machine. + * + * For example, these two lines seem similar, and run equally fast on 64-bit: + * + * xxh_u64 x; + * x ^= (x >> 47); // good + * x ^= (x >> 13); // bad + * + * However, to a 32-bit machine, there is a major difference. + * + * x ^= (x >> 47) looks like this: + * + * x.lo ^= (x.hi >> (47 - 32)); + * + * while x ^= (x >> 13) looks like this: + * + * // note: funnel shifts are not usually cheap. + * x.lo ^= (x.lo >> 13) | (x.hi << (32 - 13)); + * x.hi ^= (x.hi >> 13); + * + * The first one is significantly faster than the second, simply because the + * shift is larger than 32. This means: + * - All the bits we need are in the upper 32 bits, so we can ignore the lower + * 32 bits in the shift. + * - The shift result will always fit in the lower 32 bits, and therefore, + * we can ignore the upper 32 bits in the xor. + * + * Thanks to this optimization, XXH3 only requires these features to be efficient: + * + * - Usable unaligned access + * - A 32-bit or 64-bit ALU + * - If 32-bit, a decent ADC instruction + * - A 32 or 64-bit multiply with a 64-bit result + * - For the 128-bit variant, a decent byteswap helps short inputs. 
+ * + * The first two are already required by XXH32, and almost all 32-bit and 64-bit + * platforms which can run XXH32 can run XXH3 efficiently. + * + * Thumb-1, the classic 16-bit only subset of ARM's instruction set, is one + * notable exception. + * + * First of all, Thumb-1 lacks support for the UMULL instruction which + * performs the important long multiply. This means numerous __aeabi_lmul + * calls. + * + * Second of all, the 8 functional registers are just not enough. + * Setup for __aeabi_lmul, byteshift loads, pointers, and all arithmetic need + * Lo registers, and this shuffling results in thousands more MOVs than A32. + * + * A32 and T32 don't have this limitation. They can access all 14 registers, + * do a 32->64 multiply with UMULL, and the flexible operand allowing free + * shifts is helpful, too. + * + * Therefore, we do a quick sanity check. + * + * If compiling Thumb-1 for a target which supports ARM instructions, we will + * emit a warning, as it is not a "sane" platform to compile for. + * + * Usually, if this happens, it is because of an accident and you probably need + * to specify -march, as you likely meant to compile for a newer architecture. + * + * Credit: large sections of the vectorial and asm source code paths + * have been contributed by @easyaspi314 + */ +#if defined(__thumb__) && !defined(__thumb2__) && defined(__ARM_ARCH_ISA_ARM) +# warning "XXH3 is highly inefficient without ARM or Thumb-2." +#endif + +/* ========================================== + * Vectorization detection + * ========================================== */ + +#ifdef XXH_DOXYGEN +/*! + * @ingroup tuning + * @brief Overrides the vectorization implementation chosen for XXH3. + * + * Can be defined to 0 to disable SIMD or any of the values mentioned in + * @ref XXH_VECTOR_TYPE. + * + * If this is not defined, it uses predefined macros to determine the best + * implementation. + */ +# define XXH_VECTOR XXH_SCALAR +/*! + * @ingroup tuning + * @brief Possible values for @ref XXH_VECTOR. + * + * Note that these are actually implemented as macros. + * + * If this is not defined, it is detected automatically. + * internal macro XXH_X86DISPATCH overrides this. + */ +enum XXH_VECTOR_TYPE /* fake enum */ { + XXH_SCALAR = 0, /*!< Portable scalar version */ + XXH_SSE2 = 1, /*!< + * SSE2 for Pentium 4, Opteron, all x86_64. + * + * @note SSE2 is also guaranteed on Windows 10, macOS, and + * Android x86. + */ + XXH_AVX2 = 2, /*!< AVX2 for Haswell and Bulldozer */ + XXH_AVX512 = 3, /*!< AVX512 for Skylake and Icelake */ + XXH_NEON = 4, /*!< + * NEON for most ARMv7-A, all AArch64, and WASM SIMD128 + * via the SIMDeverywhere polyfill provided with the + * Emscripten SDK. + */ + XXH_VSX = 5, /*!< VSX and ZVector for POWER8/z13 (64-bit) */ + XXH_SVE = 6, /*!< SVE for some ARMv8-A and ARMv9-A */ +}; +/*! + * @ingroup tuning + * @brief Selects the minimum alignment for XXH3's accumulators. + * + * When using SIMD, this should match the alignment required for said vector + * type, so, for example, 32 for AVX2. + * + * Default: Auto detected. 
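+ *
+ * For example (matching the detection table below), an AVX2 build selects
+ * XXH_ACC_ALIGN == 32 so that 32-byte vector loads of the accumulators stay
+ * aligned, while a scalar build only needs 8.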
+ */ +# define XXH_ACC_ALIGN 8 +#endif + +/* Actual definition */ +#ifndef XXH_DOXYGEN +# define XXH_SCALAR 0 +# define XXH_SSE2 1 +# define XXH_AVX2 2 +# define XXH_AVX512 3 +# define XXH_NEON 4 +# define XXH_VSX 5 +# define XXH_SVE 6 +#endif + +#ifndef XXH_VECTOR /* can be defined on command line */ +# if defined(__ARM_FEATURE_SVE) +# define XXH_VECTOR XXH_SVE +# elif ( \ + defined(__ARM_NEON__) || defined(__ARM_NEON) /* gcc */ \ + || defined(_M_ARM) || defined(_M_ARM64) || defined(_M_ARM64EC) /* msvc */ \ + || (defined(__wasm_simd128__) && XXH_HAS_INCLUDE(<arm_neon.h>)) /* wasm simd128 via SIMDe */ \ + ) && ( \ + defined(_WIN32) || defined(__LITTLE_ENDIAN__) /* little endian only */ \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) \ + ) +# define XXH_VECTOR XXH_NEON +# elif defined(__AVX512F__) +# define XXH_VECTOR XXH_AVX512 +# elif defined(__AVX2__) +# define XXH_VECTOR XXH_AVX2 +# elif defined(__SSE2__) || defined(_M_AMD64) || defined(_M_X64) || (defined(_M_IX86_FP) && (_M_IX86_FP == 2)) +# define XXH_VECTOR XXH_SSE2 +# elif (defined(__PPC64__) && defined(__POWER8_VECTOR__)) \ + || (defined(__s390x__) && defined(__VEC__)) \ + && defined(__GNUC__) /* TODO: IBM XL */ +# define XXH_VECTOR XXH_VSX +# else +# define XXH_VECTOR XXH_SCALAR +# endif +#endif + +/* __ARM_FEATURE_SVE is only supported by GCC & Clang. */ +#if (XXH_VECTOR == XXH_SVE) && !defined(__ARM_FEATURE_SVE) +# ifdef _MSC_VER +# pragma warning(once : 4606) +# else +# warning "__ARM_FEATURE_SVE isn't supported. Use SCALAR instead." +# endif +# undef XXH_VECTOR +# define XXH_VECTOR XXH_SCALAR +#endif + +/* + * Controls the alignment of the accumulator, + * for compatibility with aligned vector loads, which are usually faster. + */ +#ifndef XXH_ACC_ALIGN +# if defined(XXH_X86DISPATCH) +# define XXH_ACC_ALIGN 64 /* for compatibility with avx512 */ +# elif XXH_VECTOR == XXH_SCALAR /* scalar */ +# define XXH_ACC_ALIGN 8 +# elif XXH_VECTOR == XXH_SSE2 /* sse2 */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX2 /* avx2 */ +# define XXH_ACC_ALIGN 32 +# elif XXH_VECTOR == XXH_NEON /* neon */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_VSX /* vsx */ +# define XXH_ACC_ALIGN 16 +# elif XXH_VECTOR == XXH_AVX512 /* avx512 */ +# define XXH_ACC_ALIGN 64 +# elif XXH_VECTOR == XXH_SVE /* sve */ +# define XXH_ACC_ALIGN 64 +# endif +#endif + +#if defined(XXH_X86DISPATCH) || XXH_VECTOR == XXH_SSE2 \ + || XXH_VECTOR == XXH_AVX2 || XXH_VECTOR == XXH_AVX512 +# define XXH_SEC_ALIGN XXH_ACC_ALIGN +#elif XXH_VECTOR == XXH_SVE +# define XXH_SEC_ALIGN XXH_ACC_ALIGN +#else +# define XXH_SEC_ALIGN 8 +#endif + +#if defined(__GNUC__) || defined(__clang__) +# define XXH_ALIASING __attribute__((__may_alias__)) +#else +# define XXH_ALIASING /* nothing */ +#endif + +/* + * UGLY HACK: + * GCC usually generates the best code with -O3 for xxHash. + * + * However, when targeting AVX2, it is overzealous in its unrolling resulting + * in code roughly 3/4 the speed of Clang. + * + * There are other issues, such as GCC splitting _mm256_loadu_si256 into + * _mm_loadu_si128 + _mm256_inserti128_si256. This is an optimization which + * only applies to Sandy and Ivy Bridge... which don't even support AVX2. + * + * That is why when compiling the AVX2 version, it is recommended to use either + * -O2 -mavx2 -march=haswell + * or + * -O2 -mavx2 -mno-avx256-split-unaligned-load + * for decent performance, or to use Clang instead. 
+ * + * Fortunately, we can control the first one with a pragma that forces GCC into + * -O2, but the other one we can't control without "failed to inline always + * inline function due to target mismatch" warnings. + */ +#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \ + && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__OPTIMIZE__) && XXH_SIZE_OPT <= 0 /* respect -O0 and -Os */ +# pragma GCC push_options +# pragma GCC optimize("-O2") +#endif + +#if XXH_VECTOR == XXH_NEON + +/* + * UGLY HACK: While AArch64 GCC on Linux does not seem to care, on macOS, GCC -O3 + * optimizes out the entire hashLong loop because of the aliasing violation. + * + * However, GCC is also inefficient at load-store optimization with vld1q/vst1q, + * so the only option is to mark it as aliasing. + */ +typedef uint64x2_t xxh_aliasing_uint64x2_t XXH_ALIASING; + +/*! + * @internal + * @brief `vld1q_u64` but faster and alignment-safe. + * + * On AArch64, unaligned access is always safe, but on ARMv7-a, it is only + * *conditionally* safe (`vld1` has an alignment bit like `movdq[ua]` in x86). + * + * GCC for AArch64 sees `vld1q_u8` as an intrinsic instead of a load, so it + * prohibits load-store optimizations. Therefore, a direct dereference is used. + * + * Otherwise, `vld1q_u8` is used with `vreinterpretq_u8_u64` to do a safe + * unaligned load. + */ +#if defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__) +XXH_FORCE_INLINE uint64x2_t XXH_vld1q_u64(void const* ptr) /* silence -Wcast-align */ +{ + return *(xxh_aliasing_uint64x2_t const *)ptr; +} +#else +XXH_FORCE_INLINE uint64x2_t XXH_vld1q_u64(void const* ptr) +{ + return vreinterpretq_u64_u8(vld1q_u8((uint8_t const*)ptr)); +} +#endif + +/*! + * @internal + * @brief `vmlal_u32` on low and high halves of a vector. + * + * This is a workaround for AArch64 GCC < 11 which implemented arm_neon.h with + * inline assembly and were therefore incapable of merging the `vget_{low, high}_u32` + * with `vmlal_u32`. + */ +#if defined(__aarch64__) && defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 11 +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_low_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + /* Inline assembly is the only way */ + __asm__("umlal %0.2d, %1.2s, %2.2s" : "+w" (acc) : "w" (lhs), "w" (rhs)); + return acc; +} +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_high_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + /* This intrinsic works as expected */ + return vmlal_high_u32(acc, lhs, rhs); +} +#else +/* Portable intrinsic versions */ +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_low_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + return vmlal_u32(acc, vget_low_u32(lhs), vget_low_u32(rhs)); +} +/*! @copydoc XXH_vmlal_low_u32 + * Assume the compiler converts this to vmlal_high_u32 on aarch64 */ +XXH_FORCE_INLINE uint64x2_t +XXH_vmlal_high_u32(uint64x2_t acc, uint32x4_t lhs, uint32x4_t rhs) +{ + return vmlal_u32(acc, vget_high_u32(lhs), vget_high_u32(rhs)); +} +#endif + +/*! + * @ingroup tuning + * @brief Controls the NEON to scalar ratio for XXH3 + * + * This can be set to 2, 4, 6, or 8. + * + * ARM Cortex CPUs are _very_ sensitive to how their pipelines are used. + * + * For example, the Cortex-A73 can dispatch 3 micro-ops per cycle, but only 2 of those + * can be NEON. If you are only using NEON instructions, you are only using 2/3 of the CPU + * bandwidth. 
+ * + * This is even more noticeable on the more advanced cores like the Cortex-A76 which + * can dispatch 8 micro-ops per cycle, but still only 2 NEON micro-ops at once. + * + * Therefore, to make the most out of the pipeline, it is beneficial to run 6 NEON lanes + * and 2 scalar lanes, which is chosen by default. + * + * This does not apply to Apple processors or 32-bit processors, which run better with + * full NEON. These will default to 8. Additionally, size-optimized builds run 8 lanes. + * + * This change benefits CPUs with large micro-op buffers without negatively affecting + * most other CPUs: + * + * | Chipset | Dispatch type | NEON only | 6:2 hybrid | Diff. | + * |:----------------------|:--------------------|----------:|-----------:|------:| + * | Snapdragon 730 (A76) | 2 NEON/8 micro-ops | 8.8 GB/s | 10.1 GB/s | ~16% | + * | Snapdragon 835 (A73) | 2 NEON/3 micro-ops | 5.1 GB/s | 5.3 GB/s | ~5% | + * | Marvell PXA1928 (A53) | In-order dual-issue | 1.9 GB/s | 1.9 GB/s | 0% | + * | Apple M1 | 4 NEON/8 micro-ops | 37.3 GB/s | 36.1 GB/s | ~-3% | + * + * It also seems to fix some bad codegen on GCC, making it almost as fast as clang. + * + * When using WASM SIMD128, if this is 2 or 6, SIMDe will scalarize 2 of the lanes meaning + * it effectively becomes worse 4. + * + * @see XXH3_accumulate_512_neon() + */ +# ifndef XXH3_NEON_LANES +# if (defined(__aarch64__) || defined(__arm64__) || defined(_M_ARM64) || defined(_M_ARM64EC)) \ + && !defined(__APPLE__) && XXH_SIZE_OPT <= 0 +# define XXH3_NEON_LANES 6 +# else +# define XXH3_NEON_LANES XXH_ACC_NB +# endif +# endif +#endif /* XXH_VECTOR == XXH_NEON */ + +/* + * VSX and Z Vector helpers. + * + * This is very messy, and any pull requests to clean this up are welcome. + * + * There are a lot of problems with supporting VSX and s390x, due to + * inconsistent intrinsics, spotty coverage, and multiple endiannesses. + */ +#if XXH_VECTOR == XXH_VSX +/* Annoyingly, these headers _may_ define three macros: `bool`, `vector`, + * and `pixel`. This is a problem for obvious reasons. + * + * These keywords are unnecessary; the spec literally says they are + * equivalent to `__bool`, `__vector`, and `__pixel` and may be undef'd + * after including the header. + * + * We use pragma push_macro/pop_macro to keep the namespace clean. */ +# pragma push_macro("bool") +# pragma push_macro("vector") +# pragma push_macro("pixel") +/* silence potential macro redefined warnings */ +# undef bool +# undef vector +# undef pixel + +# if defined(__s390x__) +# include <s390intrin.h> +# else +# include <altivec.h> +# endif + +/* Restore the original macro values, if applicable. */ +# pragma pop_macro("pixel") +# pragma pop_macro("vector") +# pragma pop_macro("bool") + +typedef __vector unsigned long long xxh_u64x2; +typedef __vector unsigned char xxh_u8x16; +typedef __vector unsigned xxh_u32x4; + +/* + * UGLY HACK: Similar to aarch64 macOS GCC, s390x GCC has the same aliasing issue. + */ +typedef xxh_u64x2 xxh_aliasing_u64x2 XXH_ALIASING; + +# ifndef XXH_VSX_BE +# if defined(__BIG_ENDIAN__) \ + || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) +# define XXH_VSX_BE 1 +# elif defined(__VEC_ELEMENT_REG_ORDER__) && __VEC_ELEMENT_REG_ORDER__ == __ORDER_BIG_ENDIAN__ +# warning "-maltivec=be is not recommended. Please use native endianness." 
+#    define XXH_VSX_BE 1
+#  else
+#    define XXH_VSX_BE 0
+#  endif
+# endif /* !defined(XXH_VSX_BE) */
+
+# if XXH_VSX_BE
+#  if defined(__POWER9_VECTOR__) || (defined(__clang__) && defined(__s390x__))
+#    define XXH_vec_revb vec_revb
+#  else
+/*!
+ * A polyfill for POWER9's vec_revb().
+ */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_revb(xxh_u64x2 val)
+{
+    xxh_u8x16 const vByteSwap = { 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, 0x00,
+                                  0x0F, 0x0E, 0x0D, 0x0C, 0x0B, 0x0A, 0x09, 0x08 };
+    return vec_perm(val, val, vByteSwap);
+}
+#  endif
+# endif /* XXH_VSX_BE */
+
+/*!
+ * Performs an unaligned vector load and byte swaps it on big endian.
+ */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_loadu(const void *ptr)
+{
+    xxh_u64x2 ret;
+    XXH_memcpy(&ret, ptr, sizeof(xxh_u64x2));
+# if XXH_VSX_BE
+    ret = XXH_vec_revb(ret);
+# endif
+    return ret;
+}
+
+/*
+ * vec_mulo and vec_mule are very problematic intrinsics on PowerPC.
+ *
+ * These intrinsics weren't added until GCC 8, despite existing for a while,
+ * and they are endian dependent. Also, their meanings swap depending on the
+ * version.
+ */
+# if defined(__s390x__)
+ /* s390x is always big endian, no issue on this platform */
+#  define XXH_vec_mulo vec_mulo
+#  define XXH_vec_mule vec_mule
+# elif defined(__clang__) && XXH_HAS_BUILTIN(__builtin_altivec_vmuleuw) && !defined(__ibmxl__)
+/* Clang has a better way to control this: we can just use the builtin, which doesn't swap. */
+ /* The IBM XL Compiler (which defines __clang__) only implements the vec_* operations. */
+#  define XXH_vec_mulo __builtin_altivec_vmulouw
+#  define XXH_vec_mule __builtin_altivec_vmuleuw
+# else
+/* gcc needs inline assembly */
+/* Adapted from https://github.com/google/highwayhash/blob/master/highwayhash/hh_vsx.h. */
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mulo(xxh_u32x4 a, xxh_u32x4 b)
+{
+    xxh_u64x2 result;
+    __asm__("vmulouw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+    return result;
+}
+XXH_FORCE_INLINE xxh_u64x2 XXH_vec_mule(xxh_u32x4 a, xxh_u32x4 b)
+{
+    xxh_u64x2 result;
+    __asm__("vmuleuw %0, %1, %2" : "=v" (result) : "v" (a), "v" (b));
+    return result;
+}
+# endif /* XXH_vec_mulo, XXH_vec_mule */
+#endif /* XXH_VECTOR == XXH_VSX */
+
+#if XXH_VECTOR == XXH_SVE
+#define ACCRND(acc, offset) \
+do { \
+    svuint64_t input_vec = svld1_u64(mask, xinput + offset); \
+    svuint64_t secret_vec = svld1_u64(mask, xsecret + offset); \
+    svuint64_t mixed = sveor_u64_x(mask, secret_vec, input_vec); \
+    svuint64_t swapped = svtbl_u64(input_vec, kSwap); \
+    svuint64_t mixed_lo = svextw_u64_x(mask, mixed); \
+    svuint64_t mixed_hi = svlsr_n_u64_x(mask, mixed, 32); \
+    svuint64_t mul = svmad_u64_x(mask, mixed_lo, mixed_hi, swapped); \
+    acc = svadd_u64_x(mask, acc, mul); \
+} while (0)
+#endif /* XXH_VECTOR == XXH_SVE */
+
+/* prefetch
+ * can be disabled by declaring the XXH_NO_PREFETCH build macro */
+#if defined(XXH_NO_PREFETCH)
+#  define XXH_PREFETCH(ptr) (void)(ptr)  /* disabled */
+#else
+#  if XXH_SIZE_OPT >= 1
+#    define XXH_PREFETCH(ptr) (void)(ptr)
+#  elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)) /* _mm_prefetch() not defined outside of x86/x64 */
+#    include <mmintrin.h>   /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
+#    define XXH_PREFETCH(ptr)  _mm_prefetch((const char*)(ptr), _MM_HINT_T0)
+#  elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) )
+#    define XXH_PREFETCH(ptr)  __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */)
+#  else
+#    define XXH_PREFETCH(ptr) (void)(ptr)  /* disabled */
+#  endif
+#endif  /* XXH_NO_PREFETCH */
+
+
+/* ==========================================
+ * XXH3 default settings
+ * ========================================== */
+
+#define XXH_SECRET_DEFAULT_SIZE 192   /* minimum XXH3_SECRET_SIZE_MIN */
+
+#if (XXH_SECRET_DEFAULT_SIZE < XXH3_SECRET_SIZE_MIN)
+#  error "default keyset is not large enough"
+#endif
+
+/*! Pseudorandom secret taken directly from FARSH. */
+XXH_ALIGN(64) static const xxh_u8 XXH3_kSecret[XXH_SECRET_DEFAULT_SIZE] = {
+    0xb8, 0xfe, 0x6c, 0x39, 0x23, 0xa4, 0x4b, 0xbe, 0x7c, 0x01, 0x81, 0x2c, 0xf7, 0x21, 0xad, 0x1c,
+    0xde, 0xd4, 0x6d, 0xe9, 0x83, 0x90, 0x97, 0xdb, 0x72, 0x40, 0xa4, 0xa4, 0xb7, 0xb3, 0x67, 0x1f,
+    0xcb, 0x79, 0xe6, 0x4e, 0xcc, 0xc0, 0xe5, 0x78, 0x82, 0x5a, 0xd0, 0x7d, 0xcc, 0xff, 0x72, 0x21,
+    0xb8, 0x08, 0x46, 0x74, 0xf7, 0x43, 0x24, 0x8e, 0xe0, 0x35, 0x90, 0xe6, 0x81, 0x3a, 0x26, 0x4c,
+    0x3c, 0x28, 0x52, 0xbb, 0x91, 0xc3, 0x00, 0xcb, 0x88, 0xd0, 0x65, 0x8b, 0x1b, 0x53, 0x2e, 0xa3,
+    0x71, 0x64, 0x48, 0x97, 0xa2, 0x0d, 0xf9, 0x4e, 0x38, 0x19, 0xef, 0x46, 0xa9, 0xde, 0xac, 0xd8,
+    0xa8, 0xfa, 0x76, 0x3f, 0xe3, 0x9c, 0x34, 0x3f, 0xf9, 0xdc, 0xbb, 0xc7, 0xc7, 0x0b, 0x4f, 0x1d,
+    0x8a, 0x51, 0xe0, 0x4b, 0xcd, 0xb4, 0x59, 0x31, 0xc8, 0x9f, 0x7e, 0xc9, 0xd9, 0x78, 0x73, 0x64,
+    0xea, 0xc5, 0xac, 0x83, 0x34, 0xd3, 0xeb, 0xc3, 0xc5, 0x81, 0xa0, 0xff, 0xfa, 0x13, 0x63, 0xeb,
+    0x17, 0x0d, 0xdd, 0x51, 0xb7, 0xf0, 0xda, 0x49, 0xd3, 0x16, 0x55, 0x26, 0x29, 0xd4, 0x68, 0x9e,
+    0x2b, 0x16, 0xbe, 0x58, 0x7d, 0x47, 0xa1, 0xfc, 0x8f, 0xf8, 0xb8, 0xd1, 0x7a, 0xd0, 0x31, 0xce,
+    0x45, 0xcb, 0x3a, 0x8f, 0x95, 0x16, 0x04, 0x28, 0xaf, 0xd7, 0xfb, 0xca, 0xbb, 0x4b, 0x40, 0x7e,
+};
+
+static const xxh_u64 PRIME_MX1 = 0x165667919E3779F9ULL;  /*!< 0b0001011001010110011001111001000110011110001101110111100111111001 */
+static const xxh_u64 PRIME_MX2 = 0x9FB21C651E98DF25ULL;  /*!< 0b1001111110110010000111000110010100011110100110001101111100100101 */
+
+#ifdef XXH_OLD_NAMES
+#  define kSecret XXH3_kSecret
+#endif
+
+#ifdef XXH_DOXYGEN
+/*!
+ * @brief Calculates a 32-bit to 64-bit long multiply.
+ *
+ * Implemented as a macro.
+ *
+ * Wraps `__emulu` on MSVC x86 because it tends to call `__allmul` when it doesn't
+ * need to (but it shouldn't need to anyway: it is about 7 instructions to do
+ * a 64x64 multiply...). Since we know that this will _always_ emit `MULL`, we
+ * use that instead of the normal method.
+ *
+ * If you are compiling for platforms like Thumb-1 and don't have a better option,
+ * you may also want to write your own long multiply routine here.
+ *
+ * @param x, y Numbers to be multiplied
+ * @return 64-bit product of the low 32 bits of @p x and @p y.
+ */
+XXH_FORCE_INLINE xxh_u64
+XXH_mult32to64(xxh_u64 x, xxh_u64 y)
+{
+   return (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF);
+}
+#elif defined(_MSC_VER) && defined(_M_IX86)
+#  define XXH_mult32to64(x, y) __emulu((unsigned)(x), (unsigned)(y))
+#else
+/*
+ * Downcast + upcast is usually better than masking on older compilers like
+ * GCC 4.2 (especially 32-bit ones), all without affecting newer compilers.
+ *
+ * The other method, (x & 0xFFFFFFFF) * (y & 0xFFFFFFFF), will AND both operands
+ * and perform a full 64x64 multiply -- entirely redundant on 32-bit.
+ */
+#  define XXH_mult32to64(x, y) ((xxh_u64)(xxh_u32)(x) * (xxh_u64)(xxh_u32)(y))
+#endif
+
+/*!
+ * @brief Calculates a 64->128-bit long multiply.
+ *
+ * Uses `__uint128_t` and `_umul128` if available, otherwise uses a scalar
+ * version.
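+ *
+ * For example, XXH_mult64to128(0xFFFFFFFFFFFFFFFFULL, 2) returns
+ * { .low64 = 0xFFFFFFFFFFFFFFFE, .high64 = 0x1 }, i.e. 2^65 - 2.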
+ *
+ * @param lhs, rhs The 64-bit integers to be multiplied
+ * @return The 128-bit result represented in an @ref XXH128_hash_t.
+ */
+static XXH128_hash_t
+XXH_mult64to128(xxh_u64 lhs, xxh_u64 rhs)
+{
+    /*
+     * GCC/Clang __uint128_t method.
+     *
+     * On most 64-bit targets, GCC and Clang define a __uint128_t type.
+     * This is usually the best way as it usually uses a native long 64-bit
+     * multiply, such as MULQ on x86_64 or MUL + UMULH on aarch64.
+     *
+     * Usually.
+     *
+     * On wasm, however, despite it being a 32-bit platform, Clang (and
+     * Emscripten) define this type without having native arithmetic for it.
+     * This results in a laggy compiler builtin call which calculates a full
+     * 128-bit multiply. In that case it is best to use the portable one.
+     * https://github.com/Cyan4973/xxHash/issues/211#issuecomment-515575677
+     */
+#if (defined(__GNUC__) || defined(__clang__)) && !defined(__wasm__) \
+    && defined(__SIZEOF_INT128__) \
+    || (defined(_INTEGRAL_MAX_BITS) && _INTEGRAL_MAX_BITS >= 128)
+
+    __uint128_t const product = (__uint128_t)lhs * (__uint128_t)rhs;
+    XXH128_hash_t r128;
+    r128.low64  = (xxh_u64)(product);
+    r128.high64 = (xxh_u64)(product >> 64);
+    return r128;
+
+    /*
+     * MSVC for x64's _umul128 method.
+     *
+     * xxh_u64 _umul128(xxh_u64 Multiplier, xxh_u64 Multiplicand, xxh_u64 *HighProduct);
+     *
+     * This compiles to a single-operand MUL on x64.
+     */
+#elif (defined(_M_X64) || defined(_M_IA64)) && !defined(_M_ARM64EC)
+
+#ifndef _MSC_VER
+#  pragma intrinsic(_umul128)
+#endif
+    xxh_u64 product_high;
+    xxh_u64 const product_low = _umul128(lhs, rhs, &product_high);
+    XXH128_hash_t r128;
+    r128.low64  = product_low;
+    r128.high64 = product_high;
+    return r128;
+
+    /*
+     * MSVC for ARM64's __umulh method.
+     *
+     * This compiles to the same MUL + UMULH as GCC/Clang's __uint128_t method.
+     */
+#elif defined(_M_ARM64) || defined(_M_ARM64EC)
+
+#ifndef _MSC_VER
+#  pragma intrinsic(__umulh)
+#endif
+    XXH128_hash_t r128;
+    r128.low64  = lhs * rhs;
+    r128.high64 = __umulh(lhs, rhs);
+    return r128;
+
+#else
+    /*
+     * Portable scalar method. Optimized for 32-bit and 64-bit ALUs.
+     *
+     * This is a fast and simple grade school multiply, which is shown below
+     * with base 10 arithmetic instead of base 0x100000000.
+     *
+     *           9 3 // D2 lhs = 93
+     *         x 7 5 // D2 rhs = 75
+     *     ----------
+     *           1 5 // D2 lo_lo = (93 % 10) * (75 % 10) = 15
+     *         4 5 | // D2 hi_lo = (93 / 10) * (75 % 10) = 45
+     *         2 1 | // D2 lo_hi = (93 % 10) * (75 / 10) = 21
+     *       + 6 3 | | // D2 hi_hi = (93 / 10) * (75 / 10) = 63
+     *     ---------
+     *         2 7 | // D2 cross = (15 / 10) + (45 % 10) + 21 = 27
+     *       + 6 7 | | // D2 upper = (27 / 10) + (45 / 10) + 63 = 67
+     *     ---------
+     *       6 9 7 5 // D4 res = (27 * 10) + (15 % 10) + (67 * 100) = 6975
+     *
+     * The reasons for adding the products like this are:
+     *  1. It avoids manual carry tracking. Just like how
+     *     (9 * 9) + 9 + 9 = 99, the same applies with this for UINT64_MAX.
+     *     This avoids a lot of complexity.
+     *
+     *  2. It hints for, and on Clang, compiles to, the powerful UMAAL
+     *     instruction available in ARM's Digital Signal Processing extension
+     *     in 32-bit ARMv6 and later, which is shown below:
+     *
+     *         void UMAAL(xxh_u32 *RdLo, xxh_u32 *RdHi, xxh_u32 Rn, xxh_u32 Rm)
+     *         {
+     *             xxh_u64 product = (xxh_u64)*RdLo * (xxh_u64)*RdHi + Rn + Rm;
+     *             *RdLo = (xxh_u32)(product & 0xFFFFFFFF);
+     *             *RdHi = (xxh_u32)(product >> 32);
+     *         }
+     *
+     *     This instruction was designed for efficient long multiplication, and
+     *     allows this to be calculated in only 4 instructions at speeds
+     *     comparable to some 64-bit ALUs.
+     *
+     *  3. It isn't terrible on other platforms. Usually this will be a couple
+     *     of 32-bit ADD/ADCs.
+     */
+
+    /* First calculate all of the cross products. */
+    xxh_u64 const lo_lo = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs & 0xFFFFFFFF);
+    xxh_u64 const hi_lo = XXH_mult32to64(lhs >> 32, rhs & 0xFFFFFFFF);
+    xxh_u64 const lo_hi = XXH_mult32to64(lhs & 0xFFFFFFFF, rhs >> 32);
+    xxh_u64 const hi_hi = XXH_mult32to64(lhs >> 32, rhs >> 32);
+
+    /* Now add the products together. These will never overflow. */
+    xxh_u64 const cross = (lo_lo >> 32) + (hi_lo & 0xFFFFFFFF) + lo_hi;
+    xxh_u64 const upper = (hi_lo >> 32) + (cross >> 32)        + hi_hi;
+    xxh_u64 const lower = (cross << 32) | (lo_lo & 0xFFFFFFFF);
+
+    XXH128_hash_t r128;
+    r128.low64  = lower;
+    r128.high64 = upper;
+    return r128;
+#endif
+}
+
+/*!
+ * @brief Calculates a 64-bit to 128-bit multiply, then XOR folds it.
+ *
+ * The reason for the separate function is to prevent passing too many structs
+ * around by value. This will hopefully inline the multiply, but we don't force it.
+ *
+ * @param lhs, rhs The 64-bit integers to multiply
+ * @return The low 64 bits of the product XOR'd by the high 64 bits.
+ * @see XXH_mult64to128()
+ */
+static xxh_u64
+XXH3_mul128_fold64(xxh_u64 lhs, xxh_u64 rhs)
+{
+    XXH128_hash_t product = XXH_mult64to128(lhs, rhs);
+    return product.low64 ^ product.high64;
+}
+
+/*! Seems to produce slightly better code on GCC for some reason. */
+XXH_FORCE_INLINE XXH_CONSTF xxh_u64 XXH_xorshift64(xxh_u64 v64, int shift)
+{
+    XXH_ASSERT(0 <= shift && shift < 64);
+    return v64 ^ (v64 >> shift);
+}
+
+/*
+ * This is a fast avalanche stage,
+ * suitable when input bits are already partially mixed.
+ */
+static XXH64_hash_t XXH3_avalanche(xxh_u64 h64)
+{
+    h64 = XXH_xorshift64(h64, 37);
+    h64 *= PRIME_MX1;
+    h64 = XXH_xorshift64(h64, 32);
+    return h64;
+}
+
+/*
+ * This is a stronger avalanche,
+ * inspired by Pelle Evensen's rrmxmx,
+ * preferable when input has not been previously mixed.
+ */
+static XXH64_hash_t XXH3_rrmxmx(xxh_u64 h64, xxh_u64 len)
+{
+    /* this mix is inspired by Pelle Evensen's rrmxmx */
+    h64 ^= XXH_rotl64(h64, 49) ^ XXH_rotl64(h64, 24);
+    h64 *= PRIME_MX2;
+    h64 ^= (h64 >> 35) + len;
+    h64 *= PRIME_MX2;
+    return XXH_xorshift64(h64, 28);
+}
+
+
+/* ==========================================
+ * Short keys
+ * ==========================================
+ * One of the shortcomings of XXH32 and XXH64 was that their performance was
+ * sub-optimal on short lengths. They used an iterative algorithm that strongly
+ * favored lengths that were a multiple of 4 or 8.
+ *
+ * Instead of iterating over individual inputs, we use a set of single-shot
+ * functions that piece together a range of lengths and operate in constant time.
+ *
+ * Additionally, the number of multiplies has been significantly reduced. This
+ * reduces latency, especially when emulating 64-bit multiplies on 32-bit.
+ *
+ * Depending on the platform, this may or may not be faster than XXH32, but it
+ * is almost guaranteed to be faster than XXH64.
+ */
+
+/*
+ * At very short lengths, there isn't enough input to fully hide secrets, or use
+ * the entire secret.
+ *
+ * There is also only a limited amount of mixing we can do before significantly
+ * impacting performance.
+ *
+ * Therefore, we use different sections of the secret and always mix two secret
+ * samples with an XOR. This should have no effect on performance on the
+ * seedless or withSeed variants because everything _should_ be constant folded
+ * by modern compilers.
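+ * (For instance, in XXH3_len_4to8_64b() below, the expression
+ * (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) folds into a single
+ * 64-bit constant whenever secret is the default XXH3_kSecret.)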
+ * + * The XOR mixing hides individual parts of the secret and increases entropy. + * + * This adds an extra layer of strength for custom secrets. + */ +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_1to3_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combined = { input[0], 0x01, input[0], input[0] } + * len = 2: combined = { input[1], 0x02, input[0], input[1] } + * len = 3: combined = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combined = ((xxh_u32)c1 << 16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u64 const bitflip = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const keyed = (xxh_u64)combined ^ bitflip; + return XXH64_avalanche(keyed); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_4to8_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len <= 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input1 = XXH_readLE32(input); + xxh_u32 const input2 = XXH_readLE32(input + len - 4); + xxh_u64 const bitflip = (XXH_readLE64(secret+8) ^ XXH_readLE64(secret+16)) - seed; + xxh_u64 const input64 = input2 + (((xxh_u64)input1) << 32); + xxh_u64 const keyed = input64 ^ bitflip; + return XXH3_rrmxmx(keyed, len); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_9to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(9 <= len && len <= 16); + { xxh_u64 const bitflip1 = (XXH_readLE64(secret+24) ^ XXH_readLE64(secret+32)) + seed; + xxh_u64 const bitflip2 = (XXH_readLE64(secret+40) ^ XXH_readLE64(secret+48)) - seed; + xxh_u64 const input_lo = XXH_readLE64(input) ^ bitflip1; + xxh_u64 const input_hi = XXH_readLE64(input + len - 8) ^ bitflip2; + xxh_u64 const acc = len + + XXH_swap64(input_lo) + input_hi + + XXH3_mul128_fold64(input_lo, input_hi); + return XXH3_avalanche(acc); + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t +XXH3_len_0to16_64b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (XXH_likely(len > 8)) return XXH3_len_9to16_64b(input, len, secret, seed); + if (XXH_likely(len >= 4)) return XXH3_len_4to8_64b(input, len, secret, seed); + if (len) return XXH3_len_1to3_64b(input, len, secret, seed); + return XXH64_avalanche(seed ^ (XXH_readLE64(secret+56) ^ XXH_readLE64(secret+64))); + } +} + +/* + * DISCLAIMER: There are known *seed-dependent* multicollisions here due to + * multiplication by zero, affecting hashes of lengths 17 to 240. + * + * However, they are very unlikely. + * + * Keep this in mind when using the unseeded XXH3_64bits() variant: As with all + * unseeded non-cryptographic hashes, it does not attempt to defend itself + * against specially crafted inputs, only random inputs. 
+ *
+ * Compared to classic UMAC where a 1 in 2^31 chance of 4 consecutive bytes
+ * cancelling out the secret can be taken an arbitrary number of times (addressed
+ * in XXH3_accumulate_512), this collision is very unlikely with random inputs
+ * and/or proper seeding:
+ *
+ * This only has a 1 in 2^63 chance of 8 consecutive bytes cancelling out, in a
+ * function that is only called up to 16 times per hash with up to 240 bytes of
+ * input.
+ *
+ * This is not too bad for a non-cryptographic hash function, especially with
+ * only 64 bit outputs.
+ *
+ * The 128-bit variant (which trades some speed for strength) is NOT affected
+ * by this, although it is always a good idea to use a proper seed if you care
+ * about strength.
+ */
+XXH_FORCE_INLINE xxh_u64 XXH3_mix16B(const xxh_u8* XXH_RESTRICT input,
+                                     const xxh_u8* XXH_RESTRICT secret, xxh_u64 seed64)
+{
+#if defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \
+  && defined(__i386__) && defined(__SSE2__)  /* x86 + SSE2 */ \
+  && !defined(XXH_ENABLE_AUTOVECTORIZE)      /* Define to disable like XXH32 hack */
+    /*
+     * UGLY HACK:
+     * GCC for x86 tends to autovectorize the 128-bit multiply, resulting in
+     * slower code.
+     *
+     * By forcing seed64 into a register, we disrupt the cost model and
+     * cause it to scalarize. See `XXH32_round()`
+     *
+     * FIXME: Clang's output is still _much_ faster -- On an AMD Ryzen 3600,
+     * XXH3_64bits @ len=240 runs at 4.6 GB/s with Clang 9, but 3.3 GB/s on
+     * GCC 9.2, despite both emitting scalar code.
+     *
+     * GCC generates much better scalar code than Clang for the rest of XXH3,
+     * which is why finding a more optimal codepath is of interest.
+     */
+    XXH_COMPILER_GUARD(seed64);
+#endif
+    {   xxh_u64 const input_lo = XXH_readLE64(input);
+        xxh_u64 const input_hi = XXH_readLE64(input+8);
+        return XXH3_mul128_fold64(
+            input_lo ^ (XXH_readLE64(secret)   + seed64),
+            input_hi ^ (XXH_readLE64(secret+8) - seed64)
+        );
+    }
+}
+
+/* For mid-range keys, XXH3 uses a Mum-hash variant. */
+XXH_FORCE_INLINE XXH_PUREF XXH64_hash_t
+XXH3_len_17to128_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+                     const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+                     XXH64_hash_t seed)
+{
+    XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+    XXH_ASSERT(16 < len && len <= 128);
+
+    {   xxh_u64 acc = len * XXH_PRIME64_1;
+#if XXH_SIZE_OPT >= 1
+        /* Smaller and cleaner, but slightly slower. */
+        unsigned int i = (unsigned int)(len - 1) / 32;
+        do {
+            acc += XXH3_mix16B(input+16 * i, secret+32*i, seed);
+            acc += XXH3_mix16B(input+len-16*(i+1), secret+32*i+16, seed);
+        } while (i-- != 0);
+#else
+        if (len > 32) {
+            if (len > 64) {
+                if (len > 96) {
+                    acc += XXH3_mix16B(input+48, secret+96, seed);
+                    acc += XXH3_mix16B(input+len-64, secret+112, seed);
+                }
+                acc += XXH3_mix16B(input+32, secret+64, seed);
+                acc += XXH3_mix16B(input+len-48, secret+80, seed);
+            }
+            acc += XXH3_mix16B(input+16, secret+32, seed);
+            acc += XXH3_mix16B(input+len-32, secret+48, seed);
+        }
+        acc += XXH3_mix16B(input+0, secret+0, seed);
+        acc += XXH3_mix16B(input+len-16, secret+16, seed);
+#endif
+        return XXH3_avalanche(acc);
+    }
+}
+
+XXH_NO_INLINE XXH_PUREF XXH64_hash_t
+XXH3_len_129to240_64b(const xxh_u8* XXH_RESTRICT input, size_t len,
+                      const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+                      XXH64_hash_t seed)
+{
+    XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+    XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+    #define XXH3_MIDSIZE_STARTOFFSET 3
+    #define XXH3_MIDSIZE_LASTOFFSET  17
+
+    {   xxh_u64 acc = len * XXH_PRIME64_1;
+        xxh_u64 acc_end;
+        unsigned int const nbRounds = (unsigned int)len / 16;
+        unsigned int i;
+        XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+        for (i=0; i<8; i++) {
+            acc += XXH3_mix16B(input+(16*i), secret+(16*i), seed);
+        }
+        /* last bytes */
+        acc_end = XXH3_mix16B(input + len - 16, secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET, seed);
+        XXH_ASSERT(nbRounds >= 8);
+        acc = XXH3_avalanche(acc);
+#if defined(__clang__)                                /* Clang */ \
+    && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */ \
+    && !defined(XXH_ENABLE_AUTOVECTORIZE)             /* Define to disable */
+        /*
+         * UGLY HACK:
+         * Clang for ARMv7-A tries to vectorize this loop, similar to GCC x86.
+         * Everywhere else, it uses scalar code.
+         *
+         * For 64->128-bit multiplies, even if the NEON was 100% optimal, it
+         * would still be slower than UMAAL (see XXH_mult64to128).
+         *
+         * Unfortunately, Clang doesn't handle the long multiplies properly and
+         * converts them to the nonexistent "vmulq_u64" intrinsic, which is then
+         * scalarized into an ugly mess of VMOV.32 instructions.
+         *
+         * This mess is difficult to avoid without turning autovectorization
+         * off completely, but these issues are usually relatively minor and/or
+         * not worth fixing.
+         *
+         * This loop is the easiest to fix, as unlike XXH32, this pragma
+         * _actually works_ because it is a loop vectorization instead of an
+         * SLP vectorization.
+         */
+        #pragma clang loop vectorize(disable)
+#endif
+        for (i=8 ; i < nbRounds; i++) {
+            /*
+             * Prevents Clang from unrolling the acc loop and interleaving it with this one.
+             */
+            XXH_COMPILER_GUARD(acc);
+            acc_end += XXH3_mix16B(input+(16*i), secret+(16*(i-8)) + XXH3_MIDSIZE_STARTOFFSET, seed);
+        }
+        return XXH3_avalanche(acc + acc_end);
+    }
+}
+
+
+/* ======= Long Keys ======= */
+
+#define XXH_STRIPE_LEN 64
+#define XXH_SECRET_CONSUME_RATE 8   /* nb of secret bytes consumed at each accumulation */
+#define XXH_ACC_NB (XXH_STRIPE_LEN / sizeof(xxh_u64))
+
+#ifdef XXH_OLD_NAMES
+#  define STRIPE_LEN XXH_STRIPE_LEN
+#  define ACC_NB XXH_ACC_NB
+#endif
+
+#ifndef XXH_PREFETCH_DIST
+#  ifdef __clang__
+#    define XXH_PREFETCH_DIST 320
+#  else
+#    if (XXH_VECTOR == XXH_AVX512)
+#      define XXH_PREFETCH_DIST 512
+#    else
+#      define XXH_PREFETCH_DIST 384
+#    endif
+#  endif  /* __clang__ */
+#endif  /* XXH_PREFETCH_DIST */
+
+/*
+ * These macros are to generate an XXH3_accumulate() function.
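+ * (For instance, XXH3_ACCUMULATE_TEMPLATE(sse2) defines XXH3_accumulate_sse2(),
+ * which calls XXH3_accumulate_512_sse2() once per 64-byte stripe.)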
+ *
+ * The `name` argument selects the name suffix; the target attribute is applied
+ * where the template is instantiated.
+ *
+ * The name of this symbol is XXH3_accumulate_<name>() and it calls
+ * XXH3_accumulate_512_<name>().
+ *
+ * It may be useful to hand implement this function if the compiler fails to
+ * optimize the inline function.
+ */
+#define XXH3_ACCUMULATE_TEMPLATE(name)                      \
+void                                                        \
+XXH3_accumulate_##name(xxh_u64* XXH_RESTRICT acc,           \
+                       const xxh_u8* XXH_RESTRICT input,    \
+                       const xxh_u8* XXH_RESTRICT secret,   \
+                       size_t nbStripes)                    \
+{                                                           \
+    size_t n;                                               \
+    for (n = 0; n < nbStripes; n++ ) {                      \
+        const xxh_u8* const in = input + n*XXH_STRIPE_LEN;  \
+        XXH_PREFETCH(in + XXH_PREFETCH_DIST);               \
+        XXH3_accumulate_512_##name(                         \
+                 acc,                                       \
+                 in,                                        \
+                 secret + n*XXH_SECRET_CONSUME_RATE);       \
+    }                                                       \
+}
+
+
+XXH_FORCE_INLINE void XXH_writeLE64(void* dst, xxh_u64 v64)
+{
+    if (!XXH_CPU_LITTLE_ENDIAN) v64 = XXH_swap64(v64);
+    XXH_memcpy(dst, &v64, sizeof(v64));
+}
+
+/* Several intrinsic functions below are supposed to accept __int64 as an argument,
+ * as documented in https://software.intel.com/sites/landingpage/IntrinsicsGuide/ .
+ * However, several environments do not define the __int64 type,
+ * requiring a workaround.
+ */
+#if !defined (__VMS) \
+  && (defined (__cplusplus) \
+  || (defined (__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) /* C99 */) )
+    typedef int64_t xxh_i64;
+#else
+    /* the following type must have a width of 64-bit */
+    typedef long long xxh_i64;
+#endif
+
+
+/*
+ * XXH3_accumulate_512 is the tightest loop for long inputs, and it is the most optimized.
+ *
+ * It is a hardened version of UMAC, based on FARSH's implementation.
+ *
+ * This was chosen because it adapts quite well to 32-bit, 64-bit, and SIMD
+ * implementations, and it is ridiculously fast.
+ *
+ * We harden it by mixing the original input into the accumulators as well as the product.
+ *
+ * This means that in the (relatively likely) case of a multiply by zero, the
+ * original input is preserved.
+ *
+ * On 128-bit inputs, we swap 64-bit pairs when we add the input to improve
+ * cross-pollination, as otherwise the upper and lower halves would be
+ * essentially independent.
+ *
+ * This doesn't matter on 64-bit hashes since they all get merged together in
+ * the end, so we skip the extra step.
+ *
+ * Both XXH3_64bits and XXH3_128bits use this subroutine.
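+ *
+ * In scalar form, a single lane update (see XXH3_scalarRound() below) is:
+ *
+ *     data_key    = input[i] ^ secret[i];
+ *     acc[i ^ 1] += input[i];                      (swap adjacent lanes)
+ *     acc[i]     += (data_key & 0xFFFFFFFF) * (data_key >> 32);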
+ */ + +#if (XXH_VECTOR == XXH_AVX512) \ + || (defined(XXH_DISPATCH_AVX512) && XXH_DISPATCH_AVX512 != 0) + +#ifndef XXH_TARGET_AVX512 +# define XXH_TARGET_AVX512 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_accumulate_512_avx512(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + __m512i* const xacc = (__m512i *) acc; + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + + { + /* data_vec = input[0]; */ + __m512i const data_vec = _mm512_loadu_si512 (input); + /* key_vec = secret[0]; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + /* data_key = data_vec ^ key_vec; */ + __m512i const data_key = _mm512_xor_si512 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m512i const data_key_lo = _mm512_srli_epi64 (data_key, 32); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m512i const product = _mm512_mul_epu32 (data_key, data_key_lo); + /* xacc[0] += swap(data_vec); */ + __m512i const data_swap = _mm512_shuffle_epi32(data_vec, (_MM_PERM_ENUM)_MM_SHUFFLE(1, 0, 3, 2)); + __m512i const sum = _mm512_add_epi64(*xacc, data_swap); + /* xacc[0] += product; */ + *xacc = _mm512_add_epi64(product, sum); + } +} +XXH_FORCE_INLINE XXH_TARGET_AVX512 XXH3_ACCUMULATE_TEMPLATE(avx512) + +/* + * XXH3_scrambleAcc: Scrambles the accumulators to improve mixing. + * + * Multiplication isn't perfect, as explained by Google in HighwayHash: + * + * // Multiplication mixes/scrambles bytes 0-7 of the 64-bit result to + * // varying degrees. In descending order of goodness, bytes + * // 3 4 2 5 1 6 0 7 have quality 228 224 164 160 100 96 36 32. + * // As expected, the upper and lower bytes are much worse. + * + * Source: https://github.com/google/highwayhash/blob/0aaf66b/highwayhash/hh_avx2.h#L291 + * + * Since our algorithm uses a pseudorandom secret to add some variance into the + * mix, we don't need to (or want to) mix as often or as much as HighwayHash does. + * + * This isn't as tight as XXH3_accumulate, but still written in SIMD to avoid + * extraction. + * + * Both XXH3_64bits and XXH3_128bits use this subroutine. 
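+ *
+ * In scalar form, a single lane (see XXH3_scalarScrambleRound() below) is:
+ *
+ *     acc[i] = ((acc[i] ^ (acc[i] >> 47)) ^ secret[i]) * XXH_PRIME32_1;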
+ */ + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_scrambleAcc_avx512(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 63) == 0); + XXH_STATIC_ASSERT(XXH_STRIPE_LEN == sizeof(__m512i)); + { __m512i* const xacc = (__m512i*) acc; + const __m512i prime32 = _mm512_set1_epi32((int)XXH_PRIME32_1); + + /* xacc[0] ^= (xacc[0] >> 47) */ + __m512i const acc_vec = *xacc; + __m512i const shifted = _mm512_srli_epi64 (acc_vec, 47); + /* xacc[0] ^= secret; */ + __m512i const key_vec = _mm512_loadu_si512 (secret); + __m512i const data_key = _mm512_ternarylogic_epi32(key_vec, acc_vec, shifted, 0x96 /* key_vec ^ acc_vec ^ shifted */); + + /* xacc[0] *= XXH_PRIME32_1; */ + __m512i const data_key_hi = _mm512_srli_epi64 (data_key, 32); + __m512i const prod_lo = _mm512_mul_epu32 (data_key, prime32); + __m512i const prod_hi = _mm512_mul_epu32 (data_key_hi, prime32); + *xacc = _mm512_add_epi64(prod_lo, _mm512_slli_epi64(prod_hi, 32)); + } +} + +XXH_FORCE_INLINE XXH_TARGET_AVX512 void +XXH3_initCustomSecret_avx512(void* XXH_RESTRICT customSecret, xxh_u64 seed64) +{ + XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 63) == 0); + XXH_STATIC_ASSERT(XXH_SEC_ALIGN == 64); + XXH_ASSERT(((size_t)customSecret & 63) == 0); + (void)(&XXH_writeLE64); + { int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m512i); + __m512i const seed_pos = _mm512_set1_epi64((xxh_i64)seed64); + __m512i const seed = _mm512_mask_sub_epi64(seed_pos, 0xAA, _mm512_set1_epi8(0), seed_pos); + + const __m512i* const src = (const __m512i*) ((const void*) XXH3_kSecret); + __m512i* const dest = ( __m512i*) customSecret; + int i; + XXH_ASSERT(((size_t)src & 63) == 0); /* control alignment */ + XXH_ASSERT(((size_t)dest & 63) == 0); + for (i=0; i < nbRounds; ++i) { + dest[i] = _mm512_add_epi64(_mm512_load_si512(src + i), seed); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_AVX2) \ + || (defined(XXH_DISPATCH_AVX2) && XXH_DISPATCH_AVX2 != 0) + +#ifndef XXH_TARGET_AVX2 +# define XXH_TARGET_AVX2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_AVX2 void +XXH3_accumulate_512_avx2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 31) == 0); + { __m256i* const xacc = (__m256i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */ + const __m256i* const xinput = (const __m256i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. 
*/
+        const __m256i* const xsecret = (const __m256i *) secret;
+
+        size_t i;
+        for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+            /* data_vec    = xinput[i]; */
+            __m256i const data_vec    = _mm256_loadu_si256    (xinput+i);
+            /* key_vec     = xsecret[i]; */
+            __m256i const key_vec     = _mm256_loadu_si256   (xsecret+i);
+            /* data_key    = data_vec ^ key_vec; */
+            __m256i const data_key    = _mm256_xor_si256     (data_vec, key_vec);
+            /* data_key_lo = data_key >> 32; */
+            __m256i const data_key_lo = _mm256_srli_epi64 (data_key, 32);
+            /* product     = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */
+            __m256i const product     = _mm256_mul_epu32     (data_key, data_key_lo);
+            /* xacc[i] += swap(data_vec); */
+            __m256i const data_swap = _mm256_shuffle_epi32(data_vec, _MM_SHUFFLE(1, 0, 3, 2));
+            __m256i const sum       = _mm256_add_epi64(xacc[i], data_swap);
+            /* xacc[i] += product; */
+            xacc[i] = _mm256_add_epi64(product, sum);
+    }   }
+}
+XXH_FORCE_INLINE XXH_TARGET_AVX2 XXH3_ACCUMULATE_TEMPLATE(avx2)
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void
+XXH3_scrambleAcc_avx2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+    XXH_ASSERT((((size_t)acc) & 31) == 0);
+    {   __m256i* const xacc = (__m256i*) acc;
+        /* Unaligned. This is mainly for pointer arithmetic, and because
+         * _mm256_loadu_si256 requires a const __m256i * pointer for some reason. */
+        const __m256i* const xsecret = (const __m256i *) secret;
+        const __m256i prime32 = _mm256_set1_epi32((int)XXH_PRIME32_1);
+
+        size_t i;
+        for (i=0; i < XXH_STRIPE_LEN/sizeof(__m256i); i++) {
+            /* xacc[i] ^= (xacc[i] >> 47) */
+            __m256i const acc_vec     = xacc[i];
+            __m256i const shifted     = _mm256_srli_epi64    (acc_vec, 47);
+            __m256i const data_vec    = _mm256_xor_si256     (acc_vec, shifted);
+            /* xacc[i] ^= xsecret; */
+            __m256i const key_vec     = _mm256_loadu_si256   (xsecret+i);
+            __m256i const data_key    = _mm256_xor_si256     (data_vec, key_vec);
+
+            /* xacc[i] *= XXH_PRIME32_1; */
+            __m256i const data_key_hi = _mm256_srli_epi64 (data_key, 32);
+            __m256i const prod_lo     = _mm256_mul_epu32     (data_key, prime32);
+            __m256i const prod_hi     = _mm256_mul_epu32     (data_key_hi, prime32);
+            xacc[i] = _mm256_add_epi64(prod_lo, _mm256_slli_epi64(prod_hi, 32));
+        }
+    }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_AVX2 void XXH3_initCustomSecret_avx2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+    XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 31) == 0);
+    XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE / sizeof(__m256i)) == 6);
+    XXH_STATIC_ASSERT(XXH_SEC_ALIGN <= 64);
+    (void)(&XXH_writeLE64);
+    XXH_PREFETCH(customSecret);
+    {   __m256i const seed = _mm256_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64, (xxh_i64)(0U - seed64), (xxh_i64)seed64);
+
+        const __m256i* const src  = (const __m256i*) ((const void*) XXH3_kSecret);
+        __m256i* dest = ( __m256i*) customSecret;
+
+# if defined(__GNUC__) || defined(__clang__)
+        /*
+         * On GCC & Clang, marking 'dest' as modified causes the compiler to:
+         *   - not extract the secret from SSE registers in the internal loop
+         *   - use fewer common registers, and avoid pushing these registers onto the stack
+         */
+        XXH_COMPILER_GUARD(dest);
+# endif
+        XXH_ASSERT(((size_t)src & 31) == 0); /* control alignment */
+        XXH_ASSERT(((size_t)dest & 31) == 0);
+
+        /* GCC -O2 needs the loop unrolled manually */
+        dest[0] = _mm256_add_epi64(_mm256_load_si256(src+0), seed);
+        dest[1] = _mm256_add_epi64(_mm256_load_si256(src+1), seed);
+        dest[2] = _mm256_add_epi64(_mm256_load_si256(src+2), seed);
+        dest[3] = _mm256_add_epi64(_mm256_load_si256(src+3), seed);
+        dest[4] = _mm256_add_epi64(_mm256_load_si256(src+4), seed);
+        dest[5]
= _mm256_add_epi64(_mm256_load_si256(src+5), seed); + } +} + +#endif + +/* x86dispatch always generates SSE2 */ +#if (XXH_VECTOR == XXH_SSE2) || defined(XXH_X86DISPATCH) + +#ifndef XXH_TARGET_SSE2 +# define XXH_TARGET_SSE2 /* disable attribute target */ +#endif + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_accumulate_512_sse2( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + /* SSE2 is just a half-scale version of the AVX2 version. */ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { __m128i* const xacc = (__m128i *) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xinput = (const __m128i *) input; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. */ + const __m128i* const xsecret = (const __m128i *) secret; + + size_t i; + for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) { + /* data_vec = xinput[i]; */ + __m128i const data_vec = _mm_loadu_si128 (xinput+i); + /* key_vec = xsecret[i]; */ + __m128i const key_vec = _mm_loadu_si128 (xsecret+i); + /* data_key = data_vec ^ key_vec; */ + __m128i const data_key = _mm_xor_si128 (data_vec, key_vec); + /* data_key_lo = data_key >> 32; */ + __m128i const data_key_lo = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1)); + /* product = (data_key & 0xffffffff) * (data_key_lo & 0xffffffff); */ + __m128i const product = _mm_mul_epu32 (data_key, data_key_lo); + /* xacc[i] += swap(data_vec); */ + __m128i const data_swap = _mm_shuffle_epi32(data_vec, _MM_SHUFFLE(1,0,3,2)); + __m128i const sum = _mm_add_epi64(xacc[i], data_swap); + /* xacc[i] += product; */ + xacc[i] = _mm_add_epi64(product, sum); + } } +} +XXH_FORCE_INLINE XXH_TARGET_SSE2 XXH3_ACCUMULATE_TEMPLATE(sse2) + +XXH_FORCE_INLINE XXH_TARGET_SSE2 void +XXH3_scrambleAcc_sse2(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + { __m128i* const xacc = (__m128i*) acc; + /* Unaligned. This is mainly for pointer arithmetic, and because + * _mm_loadu_si128 requires a const __m128i * pointer for some reason. 
*/
+        const __m128i* const xsecret = (const __m128i *) secret;
+        const __m128i prime32 = _mm_set1_epi32((int)XXH_PRIME32_1);
+
+        size_t i;
+        for (i=0; i < XXH_STRIPE_LEN/sizeof(__m128i); i++) {
+            /* xacc[i] ^= (xacc[i] >> 47) */
+            __m128i const acc_vec     = xacc[i];
+            __m128i const shifted     = _mm_srli_epi64    (acc_vec, 47);
+            __m128i const data_vec    = _mm_xor_si128     (acc_vec, shifted);
+            /* xacc[i] ^= xsecret[i]; */
+            __m128i const key_vec     = _mm_loadu_si128   (xsecret+i);
+            __m128i const data_key    = _mm_xor_si128     (data_vec, key_vec);
+
+            /* xacc[i] *= XXH_PRIME32_1; */
+            __m128i const data_key_hi = _mm_shuffle_epi32 (data_key, _MM_SHUFFLE(0, 3, 0, 1));
+            __m128i const prod_lo     = _mm_mul_epu32     (data_key, prime32);
+            __m128i const prod_hi     = _mm_mul_epu32     (data_key_hi, prime32);
+            xacc[i] = _mm_add_epi64(prod_lo, _mm_slli_epi64(prod_hi, 32));
+        }
+    }
+}
+
+XXH_FORCE_INLINE XXH_TARGET_SSE2 void XXH3_initCustomSecret_sse2(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+    XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+    (void)(&XXH_writeLE64);
+    {   int const nbRounds = XXH_SECRET_DEFAULT_SIZE / sizeof(__m128i);
+
+# if defined(_MSC_VER) && defined(_M_IX86) && _MSC_VER < 1900
+        /* MSVC 32bit mode does not support _mm_set_epi64x before 2015 */
+        XXH_ALIGN(16) const xxh_i64 seed64x2[2] = { (xxh_i64)seed64, (xxh_i64)(0U - seed64) };
+        __m128i const seed = _mm_load_si128((__m128i const*)seed64x2);
+# else
+        __m128i const seed = _mm_set_epi64x((xxh_i64)(0U - seed64), (xxh_i64)seed64);
+# endif
+        int i;
+
+        const void* const src16 = XXH3_kSecret;
+        __m128i* dst16 = (__m128i*) customSecret;
+# if defined(__GNUC__) || defined(__clang__)
+        /*
+         * On GCC & Clang, marking 'dst16' as modified causes the compiler to:
+         *   - not extract the secret from SSE registers in the internal loop
+         *   - use fewer common registers, and avoid pushing these registers onto the stack
+         */
+        XXH_COMPILER_GUARD(dst16);
+# endif
+        XXH_ASSERT(((size_t)src16 & 15) == 0); /* control alignment */
+        XXH_ASSERT(((size_t)dst16 & 15) == 0);
+
+        for (i=0; i < nbRounds; ++i) {
+            dst16[i] = _mm_add_epi64(_mm_load_si128((const __m128i *)src16+i), seed);
+    }   }
+}
+
+#endif
+
+#if (XXH_VECTOR == XXH_NEON)
+
+/* forward declarations for the scalar routines */
+XXH_FORCE_INLINE void
+XXH3_scalarRound(void* XXH_RESTRICT acc, void const* XXH_RESTRICT input,
+                 void const* XXH_RESTRICT secret, size_t lane);
+
+XXH_FORCE_INLINE void
+XXH3_scalarScrambleRound(void* XXH_RESTRICT acc,
+                         void const* XXH_RESTRICT secret, size_t lane);
+
+/*!
+ * @internal
+ * @brief The bulk processing loop for NEON and WASM SIMD128.
+ *
+ * The NEON code path is actually partially scalar when running on AArch64. This
+ * is to optimize the pipelining and can have up to 15% speedup depending on the
+ * CPU, and it also mitigates some GCC codegen issues.
+ *
+ * @see XXH3_NEON_LANES for configuring this and details about this optimization.
+ *
+ * NEON's 32-bit to 64-bit long multiply takes a half vector of 32-bit
+ * integers instead of the other platforms which mask full 64-bit vectors,
+ * so the setup is more complicated than just shifting right.
+ *
+ * Additionally, there is an optimization for 4 lanes at once noted below.
+ *
+ * Since, as stated, the most optimal amount of lanes for Cortexes is 6, there
+ * end up being *three* versions of the accumulate operation: one for 4 NEON
+ * lanes at a time, one for the remaining 2 NEON lanes, and the scalar round
+ * for the leftover lanes.
+ *
+ * WASM's SIMD128 uses SIMDe's arm_neon.h polyfill because the intrinsics overlap
+ * nearly perfectly.
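+ *
+ * With the default XXH3_NEON_LANES == 6 on AArch64, the loop below handles
+ * lanes 0-3 as two uint64x2_t vectors at a time, lanes 4-5 as one more
+ * uint64x2_t, and lanes 6-7 with XXH3_scalarRound().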
+ */ + +XXH_FORCE_INLINE void +XXH3_accumulate_512_neon( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + XXH_STATIC_ASSERT(XXH3_NEON_LANES > 0 && XXH3_NEON_LANES <= XXH_ACC_NB && XXH3_NEON_LANES % 2 == 0); + { /* GCC for darwin arm64 does not like aliasing here */ + xxh_aliasing_uint64x2_t* const xacc = (xxh_aliasing_uint64x2_t*) acc; + /* We don't use a uint32x4_t pointer because it causes bus errors on ARMv7. */ + uint8_t const* xinput = (const uint8_t *) input; + uint8_t const* xsecret = (const uint8_t *) secret; + + size_t i; +#ifdef __wasm_simd128__ + /* + * On WASM SIMD128, Clang emits direct address loads when XXH3_kSecret + * is constant propagated, which results in it converting it to this + * inside the loop: + * + * a = v128.load(XXH3_kSecret + 0 + $secret_offset, offset = 0) + * b = v128.load(XXH3_kSecret + 16 + $secret_offset, offset = 0) + * ... + * + * This requires a full 32-bit address immediate (and therefore a 6 byte + * instruction) as well as an add for each offset. + * + * Putting an asm guard prevents it from folding (at the cost of losing + * the alignment hint), and uses the free offset in `v128.load` instead + * of adding secret_offset each time which overall reduces code size by + * about a kilobyte and improves performance. + */ + XXH_COMPILER_GUARD(xsecret); +#endif + /* Scalar lanes use the normal scalarRound routine */ + for (i = XXH3_NEON_LANES; i < XXH_ACC_NB; i++) { + XXH3_scalarRound(acc, input, secret, i); + } + i = 0; + /* 4 NEON lanes at a time. */ + for (; i+1 < XXH3_NEON_LANES / 2; i+=2) { + /* data_vec = xinput[i]; */ + uint64x2_t data_vec_1 = XXH_vld1q_u64(xinput + (i * 16)); + uint64x2_t data_vec_2 = XXH_vld1q_u64(xinput + ((i+1) * 16)); + /* key_vec = xsecret[i]; */ + uint64x2_t key_vec_1 = XXH_vld1q_u64(xsecret + (i * 16)); + uint64x2_t key_vec_2 = XXH_vld1q_u64(xsecret + ((i+1) * 16)); + /* data_swap = swap(data_vec) */ + uint64x2_t data_swap_1 = vextq_u64(data_vec_1, data_vec_1, 1); + uint64x2_t data_swap_2 = vextq_u64(data_vec_2, data_vec_2, 1); + /* data_key = data_vec ^ key_vec; */ + uint64x2_t data_key_1 = veorq_u64(data_vec_1, key_vec_1); + uint64x2_t data_key_2 = veorq_u64(data_vec_2, key_vec_2); + + /* + * If we reinterpret the 64x2 vectors as 32x4 vectors, we can use a + * de-interleave operation for 4 lanes in 1 step with `vuzpq_u32` to + * get one vector with the low 32 bits of each lane, and one vector + * with the high 32 bits of each lane. + * + * The intrinsic returns a double vector because the original ARMv7-a + * instruction modified both arguments in place. AArch64 and SIMD128 emit + * two instructions from this intrinsic. + * + * [ dk11L | dk11H | dk12L | dk12H ] -> [ dk11L | dk12L | dk21L | dk22L ] + * [ dk21L | dk21H | dk22L | dk22H ] -> [ dk11H | dk12H | dk21H | dk22H ] + */ + uint32x4x2_t unzipped = vuzpq_u32( + vreinterpretq_u32_u64(data_key_1), + vreinterpretq_u32_u64(data_key_2) + ); + /* data_key_lo = data_key & 0xFFFFFFFF */ + uint32x4_t data_key_lo = unzipped.val[0]; + /* data_key_hi = data_key >> 32 */ + uint32x4_t data_key_hi = unzipped.val[1]; + /* + * Then, we can split the vectors horizontally and multiply which, as for most + * widening intrinsics, have a variant that works on both high half vectors + * for free on AArch64. A similar instruction is available on SIMD128. 
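+             * (Here, these are XXH_vmlal_low_u32() and XXH_vmlal_high_u32(),
+             * defined earlier in this file.)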
+ * + * sum = data_swap + (u64x2) data_key_lo * (u64x2) data_key_hi + */ + uint64x2_t sum_1 = XXH_vmlal_low_u32(data_swap_1, data_key_lo, data_key_hi); + uint64x2_t sum_2 = XXH_vmlal_high_u32(data_swap_2, data_key_lo, data_key_hi); + /* + * Clang reorders + * a += b * c; // umlal swap.2d, dkl.2s, dkh.2s + * c += a; // add acc.2d, acc.2d, swap.2d + * to + * c += a; // add acc.2d, acc.2d, swap.2d + * c += b * c; // umlal acc.2d, dkl.2s, dkh.2s + * + * While it would make sense in theory since the addition is faster, + * for reasons likely related to umlal being limited to certain NEON + * pipelines, this is worse. A compiler guard fixes this. + */ + XXH_COMPILER_GUARD_CLANG_NEON(sum_1); + XXH_COMPILER_GUARD_CLANG_NEON(sum_2); + /* xacc[i] = acc_vec + sum; */ + xacc[i] = vaddq_u64(xacc[i], sum_1); + xacc[i+1] = vaddq_u64(xacc[i+1], sum_2); + } + /* Operate on the remaining NEON lanes 2 at a time. */ + for (; i < XXH3_NEON_LANES / 2; i++) { + /* data_vec = xinput[i]; */ + uint64x2_t data_vec = XXH_vld1q_u64(xinput + (i * 16)); + /* key_vec = xsecret[i]; */ + uint64x2_t key_vec = XXH_vld1q_u64(xsecret + (i * 16)); + /* acc_vec_2 = swap(data_vec) */ + uint64x2_t data_swap = vextq_u64(data_vec, data_vec, 1); + /* data_key = data_vec ^ key_vec; */ + uint64x2_t data_key = veorq_u64(data_vec, key_vec); + /* For two lanes, just use VMOVN and VSHRN. */ + /* data_key_lo = data_key & 0xFFFFFFFF; */ + uint32x2_t data_key_lo = vmovn_u64(data_key); + /* data_key_hi = data_key >> 32; */ + uint32x2_t data_key_hi = vshrn_n_u64(data_key, 32); + /* sum = data_swap + (u64x2) data_key_lo * (u64x2) data_key_hi; */ + uint64x2_t sum = vmlal_u32(data_swap, data_key_lo, data_key_hi); + /* Same Clang workaround as before */ + XXH_COMPILER_GUARD_CLANG_NEON(sum); + /* xacc[i] = acc_vec + sum; */ + xacc[i] = vaddq_u64 (xacc[i], sum); + } + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(neon) + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_neon(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { xxh_aliasing_uint64x2_t* xacc = (xxh_aliasing_uint64x2_t*) acc; + uint8_t const* xsecret = (uint8_t const*) secret; + + size_t i; + /* WASM uses operator overloads and doesn't need these. 
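+ * (The wasm branch below multiplies the whole u64x2 vector directly:
+ *  xacc[i] = data_key * XXH_PRIME32_1.)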
*/ +#ifndef __wasm_simd128__ + /* { prime32_1, prime32_1 } */ + uint32x2_t const kPrimeLo = vdup_n_u32(XXH_PRIME32_1); + /* { 0, prime32_1, 0, prime32_1 } */ + uint32x4_t const kPrimeHi = vreinterpretq_u32_u64(vdupq_n_u64((xxh_u64)XXH_PRIME32_1 << 32)); +#endif + + /* AArch64 uses both scalar and neon at the same time */ + for (i = XXH3_NEON_LANES; i < XXH_ACC_NB; i++) { + XXH3_scalarScrambleRound(acc, secret, i); + } + for (i=0; i < XXH3_NEON_LANES / 2; i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + uint64x2_t acc_vec = xacc[i]; + uint64x2_t shifted = vshrq_n_u64(acc_vec, 47); + uint64x2_t data_vec = veorq_u64(acc_vec, shifted); + + /* xacc[i] ^= xsecret[i]; */ + uint64x2_t key_vec = XXH_vld1q_u64(xsecret + (i * 16)); + uint64x2_t data_key = veorq_u64(data_vec, key_vec); + /* xacc[i] *= XXH_PRIME32_1 */ +#ifdef __wasm_simd128__ + /* SIMD128 has multiply by u64x2, use it instead of expanding and scalarizing */ + xacc[i] = data_key * XXH_PRIME32_1; +#else + /* + * Expanded version with portable NEON intrinsics + * + * lo(x) * lo(y) + (hi(x) * lo(y) << 32) + * + * prod_hi = hi(data_key) * lo(prime) << 32 + * + * Since we only need 32 bits of this multiply a trick can be used, reinterpreting the vector + * as a uint32x4_t and multiplying by { 0, prime, 0, prime } to cancel out the unwanted bits + * and avoid the shift. + */ + uint32x4_t prod_hi = vmulq_u32 (vreinterpretq_u32_u64(data_key), kPrimeHi); + /* Extract low bits for vmlal_u32 */ + uint32x2_t data_key_lo = vmovn_u64(data_key); + /* xacc[i] = prod_hi + lo(data_key) * XXH_PRIME32_1; */ + xacc[i] = vmlal_u32(vreinterpretq_u64_u32(prod_hi), data_key_lo, kPrimeLo); +#endif + } + } +} +#endif + +#if (XXH_VECTOR == XXH_VSX) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_vsx( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + /* presumed aligned */ + xxh_aliasing_u64x2* const xacc = (xxh_aliasing_u64x2*) acc; + xxh_u8 const* const xinput = (xxh_u8 const*) input; /* no alignment restriction */ + xxh_u8 const* const xsecret = (xxh_u8 const*) secret; /* no alignment restriction */ + xxh_u64x2 const v32 = { 32, 32 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / sizeof(xxh_u64x2); i++) { + /* data_vec = xinput[i]; */ + xxh_u64x2 const data_vec = XXH_vec_loadu(xinput + 16*i); + /* key_vec = xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + 16*i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + /* shuffled = (data_key << 32) | (data_key >> 32); */ + xxh_u32x4 const shuffled = (xxh_u32x4)vec_rl(data_key, v32); + /* product = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)shuffled & 0xFFFFFFFF); */ + xxh_u64x2 const product = XXH_vec_mulo((xxh_u32x4)data_key, shuffled); + /* acc_vec = xacc[i]; */ + xxh_u64x2 acc_vec = xacc[i]; + acc_vec += product; + + /* swap high and low halves */ +#ifdef __s390x__ + acc_vec += vec_permi(data_vec, data_vec, 2); +#else + acc_vec += vec_xxpermdi(data_vec, data_vec, 2); +#endif + xacc[i] = acc_vec; + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(vsx) + +XXH_FORCE_INLINE void +XXH3_scrambleAcc_vsx(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret) +{ + XXH_ASSERT((((size_t)acc) & 15) == 0); + + { xxh_aliasing_u64x2* const xacc = (xxh_aliasing_u64x2*) acc; + const xxh_u8* const xsecret = (const xxh_u8*) secret; + /* constants */ + xxh_u64x2 const v32 = { 32, 32 }; + xxh_u64x2 const v47 = { 47, 47 }; + xxh_u32x4 const prime = { XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1, XXH_PRIME32_1 }; + size_t i; + for (i = 0; i < XXH_STRIPE_LEN / 
sizeof(xxh_u64x2); i++) { + /* xacc[i] ^= (xacc[i] >> 47); */ + xxh_u64x2 const acc_vec = xacc[i]; + xxh_u64x2 const data_vec = acc_vec ^ (acc_vec >> v47); + + /* xacc[i] ^= xsecret[i]; */ + xxh_u64x2 const key_vec = XXH_vec_loadu(xsecret + 16*i); + xxh_u64x2 const data_key = data_vec ^ key_vec; + + /* xacc[i] *= XXH_PRIME32_1 */ + /* prod_lo = ((xxh_u64x2)data_key & 0xFFFFFFFF) * ((xxh_u64x2)prime & 0xFFFFFFFF); */ + xxh_u64x2 const prod_even = XXH_vec_mule((xxh_u32x4)data_key, prime); + /* prod_hi = ((xxh_u64x2)data_key >> 32) * ((xxh_u64x2)prime >> 32); */ + xxh_u64x2 const prod_odd = XXH_vec_mulo((xxh_u32x4)data_key, prime); + xacc[i] = prod_odd + (prod_even << v32); + } } +} + +#endif + +#if (XXH_VECTOR == XXH_SVE) + +XXH_FORCE_INLINE void +XXH3_accumulate_512_sve( void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + uint64_t *xacc = (uint64_t *)acc; + const uint64_t *xinput = (const uint64_t *)(const void *)input; + const uint64_t *xsecret = (const uint64_t *)(const void *)secret; + svuint64_t kSwap = sveor_n_u64_z(svptrue_b64(), svindex_u64(0, 1), 1); + uint64_t element_count = svcntd(); + if (element_count >= 8) { + svbool_t mask = svptrue_pat_b64(SV_VL8); + svuint64_t vacc = svld1_u64(mask, xacc); + ACCRND(vacc, 0); + svst1_u64(mask, xacc, vacc); + } else if (element_count == 2) { /* sve128 */ + svbool_t mask = svptrue_pat_b64(SV_VL2); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 2); + svuint64_t acc2 = svld1_u64(mask, xacc + 4); + svuint64_t acc3 = svld1_u64(mask, xacc + 6); + ACCRND(acc0, 0); + ACCRND(acc1, 2); + ACCRND(acc2, 4); + ACCRND(acc3, 6); + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 2, acc1); + svst1_u64(mask, xacc + 4, acc2); + svst1_u64(mask, xacc + 6, acc3); + } else { + svbool_t mask = svptrue_pat_b64(SV_VL4); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 4); + ACCRND(acc0, 0); + ACCRND(acc1, 4); + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 4, acc1); + } +} + +XXH_FORCE_INLINE void +XXH3_accumulate_sve(xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, + const xxh_u8* XXH_RESTRICT secret, + size_t nbStripes) +{ + if (nbStripes != 0) { + uint64_t *xacc = (uint64_t *)acc; + const uint64_t *xinput = (const uint64_t *)(const void *)input; + const uint64_t *xsecret = (const uint64_t *)(const void *)secret; + svuint64_t kSwap = sveor_n_u64_z(svptrue_b64(), svindex_u64(0, 1), 1); + uint64_t element_count = svcntd(); + if (element_count >= 8) { + svbool_t mask = svptrue_pat_b64(SV_VL8); + svuint64_t vacc = svld1_u64(mask, xacc + 0); + do { + /* svprfd(svbool_t, void *, enum svfprop); */ + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(vacc, 0); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, vacc); + } else if (element_count == 2) { /* sve128 */ + svbool_t mask = svptrue_pat_b64(SV_VL2); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 2); + svuint64_t acc2 = svld1_u64(mask, xacc + 4); + svuint64_t acc3 = svld1_u64(mask, xacc + 6); + do { + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(acc0, 0); + ACCRND(acc1, 2); + ACCRND(acc2, 4); + ACCRND(acc3, 6); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 2, acc1); + svst1_u64(mask, xacc + 4, acc2); + svst1_u64(mask, xacc + 6, acc3); + } else { + svbool_t mask = 
svptrue_pat_b64(SV_VL4); + svuint64_t acc0 = svld1_u64(mask, xacc + 0); + svuint64_t acc1 = svld1_u64(mask, xacc + 4); + do { + svprfd(mask, xinput + 128, SV_PLDL1STRM); + ACCRND(acc0, 0); + ACCRND(acc1, 4); + xinput += 8; + xsecret += 1; + nbStripes--; + } while (nbStripes != 0); + + svst1_u64(mask, xacc + 0, acc0); + svst1_u64(mask, xacc + 4, acc1); + } + } +} + +#endif + +/* scalar variants - universal */ + +#if defined(__aarch64__) && (defined(__GNUC__) || defined(__clang__)) +/* + * In XXH3_scalarRound(), GCC and Clang have a similar codegen issue, where they + * emit an excess mask and a full 64-bit multiply-add (MADD X-form). + * + * While this might not seem like much, as AArch64 is a 64-bit architecture, only + * big Cortex designs have a full 64-bit multiplier. + * + * On the little cores, the smaller 32-bit multiplier is used, and full 64-bit + * multiplies expand to 2-3 multiplies in microcode. This has a major penalty + * of up to 4 latency cycles and 2 stall cycles in the multiply pipeline. + * + * Thankfully, AArch64 still provides the 32-bit long multiply-add (UMADDL) which does + * not have this penalty and does the mask automatically. + */ +XXH_FORCE_INLINE xxh_u64 +XXH_mult32to64_add64(xxh_u64 lhs, xxh_u64 rhs, xxh_u64 acc) +{ + xxh_u64 ret; + /* note: %x = 64-bit register, %w = 32-bit register */ + __asm__("umaddl %x0, %w1, %w2, %x3" : "=r" (ret) : "r" (lhs), "r" (rhs), "r" (acc)); + return ret; +} +#else +XXH_FORCE_INLINE xxh_u64 +XXH_mult32to64_add64(xxh_u64 lhs, xxh_u64 rhs, xxh_u64 acc) +{ + return XXH_mult32to64((xxh_u32)lhs, (xxh_u32)rhs) + acc; +} +#endif + +/*! + * @internal + * @brief Scalar round for @ref XXH3_accumulate_512_scalar(). + * + * This is extracted to its own function because the NEON path uses a combination + * of NEON and scalar. + */ +XXH_FORCE_INLINE void +XXH3_scalarRound(void* XXH_RESTRICT acc, + void const* XXH_RESTRICT input, + void const* XXH_RESTRICT secret, + size_t lane) +{ + xxh_u64* xacc = (xxh_u64*) acc; + xxh_u8 const* xinput = (xxh_u8 const*) input; + xxh_u8 const* xsecret = (xxh_u8 const*) secret; + XXH_ASSERT(lane < XXH_ACC_NB); + XXH_ASSERT(((size_t)acc & (XXH_ACC_ALIGN-1)) == 0); + { + xxh_u64 const data_val = XXH_readLE64(xinput + lane * 8); + xxh_u64 const data_key = data_val ^ XXH_readLE64(xsecret + lane * 8); + xacc[lane ^ 1] += data_val; /* swap adjacent lanes */ + xacc[lane] = XXH_mult32to64_add64(data_key /* & 0xFFFFFFFF */, data_key >> 32, xacc[lane]); + } +} + +/*! + * @internal + * @brief Processes a 64 byte block of data using the scalar path. + */ +XXH_FORCE_INLINE void +XXH3_accumulate_512_scalar(void* XXH_RESTRICT acc, + const void* XXH_RESTRICT input, + const void* XXH_RESTRICT secret) +{ + size_t i; + /* ARM GCC refuses to unroll this loop, resulting in a 24% slowdown on ARMv6. */ +#if defined(__GNUC__) && !defined(__clang__) \ + && (defined(__arm__) || defined(__thumb2__)) \ + && defined(__ARM_FEATURE_UNALIGNED) /* no unaligned access just wastes bytes */ \ + && XXH_SIZE_OPT <= 0 +# pragma GCC unroll 8 +#endif + for (i=0; i < XXH_ACC_NB; i++) { + XXH3_scalarRound(acc, input, secret, i); + } +} +XXH_FORCE_INLINE XXH3_ACCUMULATE_TEMPLATE(scalar) + +/*! + * @internal + * @brief Scalar scramble step for @ref XXH3_scrambleAcc_scalar(). + * + * This is extracted to its own function because the NEON path uses a combination + * of NEON and scalar. 
+ */
+XXH_FORCE_INLINE void
+XXH3_scalarScrambleRound(void* XXH_RESTRICT acc,
+                         void const* XXH_RESTRICT secret,
+                         size_t lane)
+{
+    xxh_u64* const xacc = (xxh_u64*) acc;   /* presumed aligned */
+    const xxh_u8* const xsecret = (const xxh_u8*) secret;   /* no alignment restriction */
+    XXH_ASSERT((((size_t)acc) & (XXH_ACC_ALIGN-1)) == 0);
+    XXH_ASSERT(lane < XXH_ACC_NB);
+    {
+        xxh_u64 const key64 = XXH_readLE64(xsecret + lane * 8);
+        xxh_u64 acc64 = xacc[lane];
+        acc64 = XXH_xorshift64(acc64, 47);
+        acc64 ^= key64;
+        acc64 *= XXH_PRIME32_1;
+        xacc[lane] = acc64;
+    }
+}
+
+/*!
+ * @internal
+ * @brief Scrambles the accumulators after a large chunk has been read
+ */
+XXH_FORCE_INLINE void
+XXH3_scrambleAcc_scalar(void* XXH_RESTRICT acc, const void* XXH_RESTRICT secret)
+{
+    size_t i;
+    for (i=0; i < XXH_ACC_NB; i++) {
+        XXH3_scalarScrambleRound(acc, secret, i);
+    }
+}
+
+XXH_FORCE_INLINE void
+XXH3_initCustomSecret_scalar(void* XXH_RESTRICT customSecret, xxh_u64 seed64)
+{
+    /*
+     * We need a separate pointer for the hack below,
+     * which requires a non-const pointer.
+     * Any decent compiler will optimize this out otherwise.
+     */
+    const xxh_u8* kSecretPtr = XXH3_kSecret;
+    XXH_STATIC_ASSERT((XXH_SECRET_DEFAULT_SIZE & 15) == 0);
+
+#if defined(__GNUC__) && defined(__aarch64__)
+    /*
+     * UGLY HACK:
+     * GCC and Clang generate a bunch of MOV/MOVK pairs for aarch64, and they are
+     * placed sequentially, in order, at the top of the unrolled loop.
+     *
+     * While MOVK is great for generating constants (2 cycles for a 64-bit
+     * constant compared to 4 cycles for LDR), it fights for bandwidth with
+     * the arithmetic instructions.
+     *
+     *   I   L   S
+     * MOVK
+     * MOVK
+     * MOVK
+     * MOVK
+     * ADD
+     * SUB      STR
+     *          STR
+     * By forcing loads from memory (as the asm line causes the compiler to assume
+     * that kSecretPtr has been changed), the pipelines are used more
+     * efficiently:
+     *   I   L   S
+     *      LDR
+     *  ADD LDR
+     *  SUB     STR
+     *          STR
+     *
+     * See XXH3_NEON_LANES for details on the pipeline.
+     *
+     * XXH3_64bits_withSeed, len == 256, Snapdragon 835
+     *   without hack: 2654.4 MB/s
+     *   with hack:    3202.9 MB/s
+     */
+    XXH_COMPILER_GUARD(kSecretPtr);
+#endif
+    {   int const nbRounds = XXH_SECRET_DEFAULT_SIZE / 16;
+        int i;
+        for (i=0; i < nbRounds; i++) {
+            /*
+             * The asm hack causes the compiler to assume that kSecretPtr aliases with
+             * customSecret, and on aarch64, this prevented LDP from merging two
+             * loads together for free. Putting the loads together before the stores
+             * properly generates LDP.
+ */ + xxh_u64 lo = XXH_readLE64(kSecretPtr + 16*i) + seed64; + xxh_u64 hi = XXH_readLE64(kSecretPtr + 16*i + 8) - seed64; + XXH_writeLE64((xxh_u8*)customSecret + 16*i, lo); + XXH_writeLE64((xxh_u8*)customSecret + 16*i + 8, hi); + } } +} + + +typedef void (*XXH3_f_accumulate)(xxh_u64* XXH_RESTRICT, const xxh_u8* XXH_RESTRICT, const xxh_u8* XXH_RESTRICT, size_t); +typedef void (*XXH3_f_scrambleAcc)(void* XXH_RESTRICT, const void*); +typedef void (*XXH3_f_initCustomSecret)(void* XXH_RESTRICT, xxh_u64); + + +#if (XXH_VECTOR == XXH_AVX512) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx512 +#define XXH3_accumulate XXH3_accumulate_avx512 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx512 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx512 + +#elif (XXH_VECTOR == XXH_AVX2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_avx2 +#define XXH3_accumulate XXH3_accumulate_avx2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_avx2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_avx2 + +#elif (XXH_VECTOR == XXH_SSE2) + +#define XXH3_accumulate_512 XXH3_accumulate_512_sse2 +#define XXH3_accumulate XXH3_accumulate_sse2 +#define XXH3_scrambleAcc XXH3_scrambleAcc_sse2 +#define XXH3_initCustomSecret XXH3_initCustomSecret_sse2 + +#elif (XXH_VECTOR == XXH_NEON) + +#define XXH3_accumulate_512 XXH3_accumulate_512_neon +#define XXH3_accumulate XXH3_accumulate_neon +#define XXH3_scrambleAcc XXH3_scrambleAcc_neon +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#elif (XXH_VECTOR == XXH_VSX) + +#define XXH3_accumulate_512 XXH3_accumulate_512_vsx +#define XXH3_accumulate XXH3_accumulate_vsx +#define XXH3_scrambleAcc XXH3_scrambleAcc_vsx +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#elif (XXH_VECTOR == XXH_SVE) +#define XXH3_accumulate_512 XXH3_accumulate_512_sve +#define XXH3_accumulate XXH3_accumulate_sve +#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#else /* scalar */ + +#define XXH3_accumulate_512 XXH3_accumulate_512_scalar +#define XXH3_accumulate XXH3_accumulate_scalar +#define XXH3_scrambleAcc XXH3_scrambleAcc_scalar +#define XXH3_initCustomSecret XXH3_initCustomSecret_scalar + +#endif + +#if XXH_SIZE_OPT >= 1 /* don't do SIMD for initialization */ +# undef XXH3_initCustomSecret +# define XXH3_initCustomSecret XXH3_initCustomSecret_scalar +#endif + +XXH_FORCE_INLINE void +XXH3_hashLong_internal_loop(xxh_u64* XXH_RESTRICT acc, + const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble) +{ + size_t const nbStripesPerBlock = (secretSize - XXH_STRIPE_LEN) / XXH_SECRET_CONSUME_RATE; + size_t const block_len = XXH_STRIPE_LEN * nbStripesPerBlock; + size_t const nb_blocks = (len - 1) / block_len; + + size_t n; + + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + + for (n = 0; n < nb_blocks; n++) { + f_acc(acc, input + n*block_len, secret, nbStripesPerBlock); + f_scramble(acc, secret + secretSize - XXH_STRIPE_LEN); + } + + /* last partial block */ + XXH_ASSERT(len > XXH_STRIPE_LEN); + { size_t const nbStripes = ((len - 1) - (block_len * nb_blocks)) / XXH_STRIPE_LEN; + XXH_ASSERT(nbStripes <= (secretSize / XXH_SECRET_CONSUME_RATE)); + f_acc(acc, input + nb_blocks*block_len, secret, nbStripes); + + /* last stripe */ + { const xxh_u8* const p = input + len - XXH_STRIPE_LEN; +#define XXH_SECRET_LASTACC_START 7 /* not aligned on 8, last secret is different from acc & scrambler */ + XXH3_accumulate_512(acc, p, 
secret + secretSize - XXH_STRIPE_LEN - XXH_SECRET_LASTACC_START);
+    }   }
+}
+
+XXH_FORCE_INLINE xxh_u64
+XXH3_mix2Accs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret)
+{
+    return XXH3_mul128_fold64(
+               acc[0] ^ XXH_readLE64(secret),
+               acc[1] ^ XXH_readLE64(secret+8) );
+}
+
+static XXH64_hash_t
+XXH3_mergeAccs(const xxh_u64* XXH_RESTRICT acc, const xxh_u8* XXH_RESTRICT secret, xxh_u64 start)
+{
+    xxh_u64 result64 = start;
+    size_t i = 0;
+
+    for (i = 0; i < 4; i++) {
+        result64 += XXH3_mix2Accs(acc+2*i, secret + 16*i);
+#if defined(__clang__)                                /* Clang */ \
+    && (defined(__arm__) || defined(__thumb__))       /* ARMv7 */ \
+    && (defined(__ARM_NEON) || defined(__ARM_NEON__)) /* NEON */  \
+    && !defined(XXH_ENABLE_AUTOVECTORIZE)             /* Define to disable */
+        /*
+         * UGLY HACK:
+         * Prevent autovectorization on Clang ARMv7-a. Exact same problem as
+         * the one in XXH3_len_129to240_64b. Speeds up shorter keys > 240b.
+         * XXH3_64bits, len == 256, Snapdragon 835:
+         *   without hack: 2063.7 MB/s
+         *   with hack:    2560.7 MB/s
+         */
+        XXH_COMPILER_GUARD(result64);
+#endif
+    }
+
+    return XXH3_avalanche(result64);
+}
+
+#define XXH3_INIT_ACC { XXH_PRIME32_3, XXH_PRIME64_1, XXH_PRIME64_2, XXH_PRIME64_3, \
+                        XXH_PRIME64_4, XXH_PRIME32_2, XXH_PRIME64_5, XXH_PRIME32_1 }
+
+XXH_FORCE_INLINE XXH64_hash_t
+XXH3_hashLong_64b_internal(const void* XXH_RESTRICT input, size_t len,
+                           const void* XXH_RESTRICT secret, size_t secretSize,
+                           XXH3_f_accumulate f_acc,
+                           XXH3_f_scrambleAcc f_scramble)
+{
+    XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+    XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, (const xxh_u8*)secret, secretSize, f_acc, f_scramble);
+
+    /* converge into final hash */
+    XXH_STATIC_ASSERT(sizeof(acc) == 64);
+    /* do not align on 8, so that the secret is different from the accumulator */
+#define XXH_SECRET_MERGEACCS_START 11
+    XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+    return XXH3_mergeAccs(acc, (const xxh_u8*)secret + XXH_SECRET_MERGEACCS_START, (xxh_u64)len * XXH_PRIME64_1);
+}
+
+/*
+ * It's important for performance to transmit secret's size (when it's static)
+ * so that the compiler can properly optimize the vectorized loop.
+ * This makes a big performance difference for "medium" keys (<1 KB) when using the AVX instruction set.
+ * When the secret size is unknown, or on GCC 12 where the mix of NO_INLINE and FORCE_INLINE
+ * breaks -Og, this is XXH_NO_INLINE.
+ */
+XXH3_WITH_SECRET_INLINE XXH64_hash_t
+XXH3_hashLong_64b_withSecret(const void* XXH_RESTRICT input, size_t len,
+                             XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen)
+{
+    (void)seed64;
+    return XXH3_hashLong_64b_internal(input, len, secret, secretLen, XXH3_accumulate, XXH3_scrambleAcc);
+}
+
+/*
+ * It's preferable for performance that XXH3_hashLong is not inlined,
+ * as it results in a smaller function for small data, easier on the instruction cache.
+ * Note that inside this no_inline function, we do inline the internal loop,
+ * and provide a statically defined secret size, to allow optimization of the vector loop.
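+ *
+ * (Editor's note) For reference, XXH3_64bits() further below reaches the
+ * function after this comment through the f_hashLong pointer of
+ * XXH3_64bits_internal(), and the secret size used inside is the
+ * compile-time constant sizeof(XXH3_kSecret).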
+ */ +XXH_NO_INLINE XXH_PUREF XXH64_hash_t +XXH3_hashLong_64b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_64b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_accumulate, XXH3_scrambleAcc); +} + +/* + * XXH3_hashLong_64b_withSeed(): + * Generate a custom key based on alteration of default XXH3_kSecret with the seed, + * and then use this key for long mode hashing. + * + * This operation is decently fast but nonetheless costs a little bit of time. + * Try to avoid it whenever possible (typically when seed==0). + * + * It's important for performance that XXH3_hashLong is not inlined. Not sure + * why (uop cache maybe?), but the difference is large and easily measurable. + */ +XXH_FORCE_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed_internal(const void* input, size_t len, + XXH64_hash_t seed, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ +#if XXH_SIZE_OPT <= 0 + if (seed == 0) + return XXH3_hashLong_64b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc, f_scramble); +#endif + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed); + return XXH3_hashLong_64b_internal(input, len, secret, sizeof(secret), + f_acc, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH64_hash_t +XXH3_hashLong_64b_withSeed(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed, const xxh_u8* XXH_RESTRICT secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_64b_withSeed_internal(input, len, seed, + XXH3_accumulate, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + + +typedef XXH64_hash_t (*XXH3_hashLong64_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const xxh_u8* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH64_hash_t +XXH3_64bits_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong64_f f_hashLong) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secretLen` condition is not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + * Also, note that function signature doesn't offer room to return an error. + */ + if (len <= 16) + return XXH3_len_0to16_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_64b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hashLong(input, len, seed64, (const xxh_u8*)secret, secretLen); +} + + +/* === Public entry point === */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits(XXH_NOESCAPE const void* input, size_t length) +{ + return XXH3_64bits_internal(input, length, 0, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_default); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSecret(XXH_NOESCAPE const void* input, size_t length, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_64bits_internal(input, length, 0, secret, secretSize, XXH3_hashLong_64b_withSecret); +} + +/*! 
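+ * Seeded 64-bit variant. Usage sketch (editor's note; `data` and `size`
+ * are hypothetical placeholders):
+ * @code
+ *     XXH64_hash_t const h = XXH3_64bits_withSeed(data, size, 1234);
+ * @endcode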
@ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSeed(XXH_NOESCAPE const void* input, size_t length, XXH64_hash_t seed) +{ + return XXH3_64bits_internal(input, length, seed, XXH3_kSecret, sizeof(XXH3_kSecret), XXH3_hashLong_64b_withSeed); +} + +XXH_PUBLIC_API XXH64_hash_t +XXH3_64bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t length, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + if (length <= XXH3_MIDSIZE_MAX) + return XXH3_64bits_internal(input, length, seed, XXH3_kSecret, sizeof(XXH3_kSecret), NULL); + return XXH3_hashLong_64b_withSecret(input, length, seed, (const xxh_u8*)secret, secretSize); +} + + +/* === XXH3 streaming === */ +#ifndef XXH_NO_STREAM +/* + * Malloc's a pointer that is always aligned to align. + * + * This must be freed with `XXH_alignedFree()`. + * + * malloc typically guarantees 16 byte alignment on 64-bit systems and 8 byte + * alignment on 32-bit. This isn't enough for the 32 byte aligned loads in AVX2 + * or on 32-bit, the 16 byte aligned loads in SSE2 and NEON. + * + * This underalignment previously caused a rather obvious crash which went + * completely unnoticed due to XXH3_createState() not actually being tested. + * Credit to RedSpah for noticing this bug. + * + * The alignment is done manually: Functions like posix_memalign or _mm_malloc + * are avoided: To maintain portability, we would have to write a fallback + * like this anyways, and besides, testing for the existence of library + * functions without relying on external build tools is impossible. + * + * The method is simple: Overallocate, manually align, and store the offset + * to the original behind the returned pointer. + * + * Align must be a power of 2 and 8 <= align <= 128. + */ +static XXH_MALLOCF void* XXH_alignedMalloc(size_t s, size_t align) +{ + XXH_ASSERT(align <= 128 && align >= 8); /* range check */ + XXH_ASSERT((align & (align-1)) == 0); /* power of 2 */ + XXH_ASSERT(s != 0 && s < (s + align)); /* empty/overflow */ + { /* Overallocate to make room for manual realignment and an offset byte */ + xxh_u8* base = (xxh_u8*)XXH_malloc(s + align); + if (base != NULL) { + /* + * Get the offset needed to align this pointer. + * + * Even if the returned pointer is aligned, there will always be + * at least one byte to store the offset to the original pointer. + */ + size_t offset = align - ((size_t)base & (align - 1)); /* base % align */ + /* Add the offset for the now-aligned pointer */ + xxh_u8* ptr = base + offset; + + XXH_ASSERT((size_t)ptr % align == 0); + + /* Store the offset immediately before the returned pointer. */ + ptr[-1] = (xxh_u8)offset; + return ptr; + } + return NULL; + } +} +/* + * Frees an aligned pointer allocated by XXH_alignedMalloc(). Don't pass + * normal malloc'd pointers, XXH_alignedMalloc has a specific data layout. + */ +static void XXH_alignedFree(void* p) +{ + if (p != NULL) { + xxh_u8* ptr = (xxh_u8*)p; + /* Get the offset byte we added in XXH_malloc. */ + xxh_u8 offset = ptr[-1]; + /* Free the original malloc'd pointer */ + xxh_u8* base = ptr - offset; + XXH_free(base); + } +} +/*! @ingroup XXH3_family */ +/*! + * @brief Allocate an @ref XXH3_state_t. + * + * @return An allocated pointer of @ref XXH3_state_t on success. + * @return `NULL` on failure. + * + * @note Must be freed with XXH3_freeState(). 
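+ *
+ * Typical lifecycle (editor's sketch; `data`/`size` are placeholders and
+ * error handling is elided):
+ * @code
+ *     XXH3_state_t* const st = XXH3_createState();
+ *     XXH3_64bits_reset(st);
+ *     XXH3_64bits_update(st, data, size);
+ *     XXH64_hash_t const h = XXH3_64bits_digest(st);
+ *     XXH3_freeState(st);
+ * @endcode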
+ * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH3_state_t* XXH3_createState(void) +{ + XXH3_state_t* const state = (XXH3_state_t*)XXH_alignedMalloc(sizeof(XXH3_state_t), 64); + if (state==NULL) return NULL; + XXH3_INITSTATE(state); + return state; +} + +/*! @ingroup XXH3_family */ +/*! + * @brief Frees an @ref XXH3_state_t. + * + * @param statePtr A pointer to an @ref XXH3_state_t allocated with @ref XXH3_createState(). + * + * @return @ref XXH_OK. + * + * @note Must be allocated with XXH3_createState(). + * + * @see @ref streaming_example "Streaming Example" + */ +XXH_PUBLIC_API XXH_errorcode XXH3_freeState(XXH3_state_t* statePtr) +{ + XXH_alignedFree(statePtr); + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API void +XXH3_copyState(XXH_NOESCAPE XXH3_state_t* dst_state, XXH_NOESCAPE const XXH3_state_t* src_state) +{ + XXH_memcpy(dst_state, src_state, sizeof(*dst_state)); +} + +static void +XXH3_reset_internal(XXH3_state_t* statePtr, + XXH64_hash_t seed, + const void* secret, size_t secretSize) +{ + size_t const initStart = offsetof(XXH3_state_t, bufferedSize); + size_t const initLength = offsetof(XXH3_state_t, nbStripesPerBlock) - initStart; + XXH_ASSERT(offsetof(XXH3_state_t, nbStripesPerBlock) > initStart); + XXH_ASSERT(statePtr != NULL); + /* set members from bufferedSize to nbStripesPerBlock (excluded) to 0 */ + memset((char*)statePtr + initStart, 0, initLength); + statePtr->acc[0] = XXH_PRIME32_3; + statePtr->acc[1] = XXH_PRIME64_1; + statePtr->acc[2] = XXH_PRIME64_2; + statePtr->acc[3] = XXH_PRIME64_3; + statePtr->acc[4] = XXH_PRIME64_4; + statePtr->acc[5] = XXH_PRIME32_2; + statePtr->acc[6] = XXH_PRIME64_5; + statePtr->acc[7] = XXH_PRIME32_1; + statePtr->seed = seed; + statePtr->useSeed = (seed != 0); + statePtr->extSecret = (const unsigned char*)secret; + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); + statePtr->secretLimit = secretSize - XXH_STRIPE_LEN; + statePtr->nbStripesPerBlock = statePtr->secretLimit / XXH_SECRET_CONSUME_RATE; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_reset_internal(statePtr, 0, XXH3_kSecret, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + if (statePtr == NULL) return XXH_ERROR; + XXH3_reset_internal(statePtr, 0, secret, secretSize); + if (secret == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + if (statePtr == NULL) return XXH_ERROR; + if (seed==0) return XXH3_64bits_reset(statePtr); + if ((seed != statePtr->seed) || (statePtr->extSecret != NULL)) + XXH3_initCustomSecret(statePtr->customSecret, seed); + XXH3_reset_internal(statePtr, seed, NULL, XXH_SECRET_DEFAULT_SIZE); + return XXH_OK; +} + +/*! 
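+ * (Editor's note) The state keeps only a pointer to @p secret (stored as
+ * extSecret by XXH3_reset_internal()), so the secret buffer must remain
+ * valid and unmodified for the whole streaming session.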
@ingroup XXH3_family */
+XXH_PUBLIC_API XXH_errorcode
+XXH3_64bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed64)
+{
+    if (statePtr == NULL) return XXH_ERROR;
+    if (secret == NULL) return XXH_ERROR;
+    if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR;
+    XXH3_reset_internal(statePtr, seed64, secret, secretSize);
+    statePtr->useSeed = 1; /* always, even if seed64==0 */
+    return XXH_OK;
+}
+
+/*!
+ * @internal
+ * @brief Processes a large input for XXH3_update() and XXH3_digest_long().
+ *
+ * Unlike XXH3_hashLong_internal_loop(), this can process data that overlaps a block.
+ *
+ * @param acc                Pointer to the 8 accumulator lanes
+ * @param nbStripesSoFarPtr  In/out pointer to the number of leftover stripes in the block
+ * @param nbStripesPerBlock  Number of stripes in a block
+ * @param input              Input pointer
+ * @param nbStripes          Number of stripes to process
+ * @param secret             Secret pointer
+ * @param secretLimit        Offset of the last block in @p secret
+ * @param f_acc              Pointer to an XXH3_accumulate implementation
+ * @param f_scramble         Pointer to an XXH3_scrambleAcc implementation
+ * @return                   Pointer past the end of @p input after processing
+ */
+XXH_FORCE_INLINE const xxh_u8 *
+XXH3_consumeStripes(xxh_u64* XXH_RESTRICT acc,
+                    size_t* XXH_RESTRICT nbStripesSoFarPtr, size_t nbStripesPerBlock,
+                    const xxh_u8* XXH_RESTRICT input, size_t nbStripes,
+                    const xxh_u8* XXH_RESTRICT secret, size_t secretLimit,
+                    XXH3_f_accumulate f_acc,
+                    XXH3_f_scrambleAcc f_scramble)
+{
+    const xxh_u8* initialSecret = secret + *nbStripesSoFarPtr * XXH_SECRET_CONSUME_RATE;
+    /* Process full blocks */
+    if (nbStripes >= (nbStripesPerBlock - *nbStripesSoFarPtr)) {
+        /* Process the initial partial block... */
+        size_t nbStripesThisIter = nbStripesPerBlock - *nbStripesSoFarPtr;
+
+        do {
+            /* Accumulate and scramble */
+            f_acc(acc, input, initialSecret, nbStripesThisIter);
+            f_scramble(acc, secret + secretLimit);
+            input += nbStripesThisIter * XXH_STRIPE_LEN;
+            nbStripes -= nbStripesThisIter;
+            /* Then continue the loop with the full block size */
+            nbStripesThisIter = nbStripesPerBlock;
+            initialSecret = secret;
+        } while (nbStripes >= nbStripesPerBlock);
+        *nbStripesSoFarPtr = 0;
+    }
+    /* Process a partial block */
+    if (nbStripes > 0) {
+        f_acc(acc, input, initialSecret, nbStripes);
+        input += nbStripes * XXH_STRIPE_LEN;
+        *nbStripesSoFarPtr += nbStripes;
+    }
+    /* Return end pointer */
+    return input;
+}
+
+#ifndef XXH3_STREAM_USE_STACK
+# if XXH_SIZE_OPT <= 0 && !defined(__clang__) /* clang doesn't need additional stack space */
+#   define XXH3_STREAM_USE_STACK 1
+# endif
+#endif
+/*
+ * Both XXH3_64bits_update and XXH3_128bits_update use this routine.
+ */
+XXH_FORCE_INLINE XXH_errorcode
+XXH3_update(XXH3_state_t* XXH_RESTRICT const state,
+            const xxh_u8* XXH_RESTRICT input, size_t len,
+            XXH3_f_accumulate f_acc,
+            XXH3_f_scrambleAcc f_scramble)
+{
+    if (input==NULL) {
+        XXH_ASSERT(len == 0);
+        return XXH_OK;
+    }
+
+    XXH_ASSERT(state != NULL);
+    {   const xxh_u8* const bEnd = input + len;
+        const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret;
+#if defined(XXH3_STREAM_USE_STACK) && XXH3_STREAM_USE_STACK >= 1
+        /* For some reason, gcc and MSVC seem to suffer greatly
+         * when operating on the accumulators directly in state.
+         * Operating into stack space seems to enable proper optimization.
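+         * (Editor's note: working on a local copy appears to let the
+         * compiler keep the eight lanes in registers, since writes through
+         * `state` could otherwise be treated as aliasing other fields.)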
+ * clang, on the other hand, doesn't seem to need this trick */ + XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[8]; + XXH_memcpy(acc, state->acc, sizeof(acc)); +#else + xxh_u64* XXH_RESTRICT const acc = state->acc; +#endif + state->totalLen += len; + XXH_ASSERT(state->bufferedSize <= XXH3_INTERNALBUFFER_SIZE); + + /* small input : just fill in tmp buffer */ + if (len <= XXH3_INTERNALBUFFER_SIZE - state->bufferedSize) { + XXH_memcpy(state->buffer + state->bufferedSize, input, len); + state->bufferedSize += (XXH32_hash_t)len; + return XXH_OK; + } + + /* total input is now > XXH3_INTERNALBUFFER_SIZE */ + #define XXH3_INTERNALBUFFER_STRIPES (XXH3_INTERNALBUFFER_SIZE / XXH_STRIPE_LEN) + XXH_STATIC_ASSERT(XXH3_INTERNALBUFFER_SIZE % XXH_STRIPE_LEN == 0); /* clean multiple */ + + /* + * Internal buffer is partially filled (always, except at beginning) + * Complete it, then consume it. + */ + if (state->bufferedSize) { + size_t const loadSize = XXH3_INTERNALBUFFER_SIZE - state->bufferedSize; + XXH_memcpy(state->buffer + state->bufferedSize, input, loadSize); + input += loadSize; + XXH3_consumeStripes(acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, XXH3_INTERNALBUFFER_STRIPES, + secret, state->secretLimit, + f_acc, f_scramble); + state->bufferedSize = 0; + } + XXH_ASSERT(input < bEnd); + if (bEnd - input > XXH3_INTERNALBUFFER_SIZE) { + size_t nbStripes = (size_t)(bEnd - 1 - input) / XXH_STRIPE_LEN; + input = XXH3_consumeStripes(acc, + &state->nbStripesSoFar, state->nbStripesPerBlock, + input, nbStripes, + secret, state->secretLimit, + f_acc, f_scramble); + XXH_memcpy(state->buffer + sizeof(state->buffer) - XXH_STRIPE_LEN, input - XXH_STRIPE_LEN, XXH_STRIPE_LEN); + + } + /* Some remaining input (always) : buffer it */ + XXH_ASSERT(input < bEnd); + XXH_ASSERT(bEnd - input <= XXH3_INTERNALBUFFER_SIZE); + XXH_ASSERT(state->bufferedSize == 0); + XXH_memcpy(state->buffer, input, (size_t)(bEnd-input)); + state->bufferedSize = (XXH32_hash_t)(bEnd-input); +#if defined(XXH3_STREAM_USE_STACK) && XXH3_STREAM_USE_STACK >= 1 + /* save stack accumulators into state */ + XXH_memcpy(state->acc, acc, sizeof(acc)); +#endif + } + + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_64bits_update(XXH_NOESCAPE XXH3_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_update(state, (const xxh_u8*)input, len, + XXH3_accumulate, XXH3_scrambleAcc); +} + + +XXH_FORCE_INLINE void +XXH3_digest_long (XXH64_hash_t* acc, + const XXH3_state_t* state, + const unsigned char* secret) +{ + xxh_u8 lastStripe[XXH_STRIPE_LEN]; + const xxh_u8* lastStripePtr; + + /* + * Digest on a local copy. This way, the state remains unaltered, and it can + * continue ingesting more input afterwards. 
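+     * (Editor's note: this is what makes a digest non-destructive, e.g. an
+     * intermediate XXH3_64bits_digest() may be followed by further
+     * XXH3_64bits_update() calls on the same state.)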
+ */ + XXH_memcpy(acc, state->acc, sizeof(state->acc)); + if (state->bufferedSize >= XXH_STRIPE_LEN) { + /* Consume remaining stripes then point to remaining data in buffer */ + size_t const nbStripes = (state->bufferedSize - 1) / XXH_STRIPE_LEN; + size_t nbStripesSoFar = state->nbStripesSoFar; + XXH3_consumeStripes(acc, + &nbStripesSoFar, state->nbStripesPerBlock, + state->buffer, nbStripes, + secret, state->secretLimit, + XXH3_accumulate, XXH3_scrambleAcc); + lastStripePtr = state->buffer + state->bufferedSize - XXH_STRIPE_LEN; + } else { /* bufferedSize < XXH_STRIPE_LEN */ + /* Copy to temp buffer */ + size_t const catchupSize = XXH_STRIPE_LEN - state->bufferedSize; + XXH_ASSERT(state->bufferedSize > 0); /* there is always some input buffered */ + XXH_memcpy(lastStripe, state->buffer + sizeof(state->buffer) - catchupSize, catchupSize); + XXH_memcpy(lastStripe + catchupSize, state->buffer, state->bufferedSize); + lastStripePtr = lastStripe; + } + /* Last stripe */ + XXH3_accumulate_512(acc, + lastStripePtr, + secret + state->secretLimit - XXH_SECRET_LASTACC_START); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH64_hash_t XXH3_64bits_digest (XXH_NOESCAPE const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? state->customSecret : state->extSecret; + if (state->totalLen > XXH3_MIDSIZE_MAX) { + XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB]; + XXH3_digest_long(acc, state, secret); + return XXH3_mergeAccs(acc, + secret + XXH_SECRET_MERGEACCS_START, + (xxh_u64)state->totalLen * XXH_PRIME64_1); + } + /* totalLen <= XXH3_MIDSIZE_MAX: digesting a short input */ + if (state->useSeed) + return XXH3_64bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed); + return XXH3_64bits_withSecret(state->buffer, (size_t)(state->totalLen), + secret, state->secretLimit + XXH_STRIPE_LEN); +} +#endif /* !XXH_NO_STREAM */ + + +/* ========================================== + * XXH3 128 bits (a.k.a XXH128) + * ========================================== + * XXH3's 128-bit variant has better mixing and strength than the 64-bit variant, + * even without counting the significantly larger output size. + * + * For example, extra steps are taken to avoid the seed-dependent collisions + * in 17-240 byte inputs (See XXH3_mix16B and XXH128_mix32B). + * + * This strength naturally comes at the cost of some speed, especially on short + * lengths. Note that longer hashes are about as fast as the 64-bit version + * due to it using only a slight modification of the 64-bit loop. + * + * XXH128 is also more oriented towards 64-bit machines. It is still extremely + * fast for a _128-bit_ hash on 32-bit (it usually clears XXH64). + */ + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_1to3_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + /* A doubled version of 1to3_64b with different constants. 
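+     * (Editor's note) "Doubled" in the sense that the same three input
+     * bytes feed two 32-bit words: combinedh below is combinedl
+     * byte-swapped and rotated, so each half of the 128-bit result mixes
+     * the bytes in a different order.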
*/ + XXH_ASSERT(input != NULL); + XXH_ASSERT(1 <= len && len <= 3); + XXH_ASSERT(secret != NULL); + /* + * len = 1: combinedl = { input[0], 0x01, input[0], input[0] } + * len = 2: combinedl = { input[1], 0x02, input[0], input[1] } + * len = 3: combinedl = { input[2], 0x03, input[0], input[1] } + */ + { xxh_u8 const c1 = input[0]; + xxh_u8 const c2 = input[len >> 1]; + xxh_u8 const c3 = input[len - 1]; + xxh_u32 const combinedl = ((xxh_u32)c1 <<16) | ((xxh_u32)c2 << 24) + | ((xxh_u32)c3 << 0) | ((xxh_u32)len << 8); + xxh_u32 const combinedh = XXH_rotl32(XXH_swap32(combinedl), 13); + xxh_u64 const bitflipl = (XXH_readLE32(secret) ^ XXH_readLE32(secret+4)) + seed; + xxh_u64 const bitfliph = (XXH_readLE32(secret+8) ^ XXH_readLE32(secret+12)) - seed; + xxh_u64 const keyed_lo = (xxh_u64)combinedl ^ bitflipl; + xxh_u64 const keyed_hi = (xxh_u64)combinedh ^ bitfliph; + XXH128_hash_t h128; + h128.low64 = XXH64_avalanche(keyed_lo); + h128.high64 = XXH64_avalanche(keyed_hi); + return h128; + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_4to8_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(4 <= len && len <= 8); + seed ^= (xxh_u64)XXH_swap32((xxh_u32)seed) << 32; + { xxh_u32 const input_lo = XXH_readLE32(input); + xxh_u32 const input_hi = XXH_readLE32(input + len - 4); + xxh_u64 const input_64 = input_lo + ((xxh_u64)input_hi << 32); + xxh_u64 const bitflip = (XXH_readLE64(secret+16) ^ XXH_readLE64(secret+24)) + seed; + xxh_u64 const keyed = input_64 ^ bitflip; + + /* Shift len to the left to ensure it is even, this avoids even multiplies. */ + XXH128_hash_t m128 = XXH_mult64to128(keyed, XXH_PRIME64_1 + (len << 2)); + + m128.high64 += (m128.low64 << 1); + m128.low64 ^= (m128.high64 >> 3); + + m128.low64 = XXH_xorshift64(m128.low64, 35); + m128.low64 *= PRIME_MX2; + m128.low64 = XXH_xorshift64(m128.low64, 28); + m128.high64 = XXH3_avalanche(m128.high64); + return m128; + } +} + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_9to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(input != NULL); + XXH_ASSERT(secret != NULL); + XXH_ASSERT(9 <= len && len <= 16); + { xxh_u64 const bitflipl = (XXH_readLE64(secret+32) ^ XXH_readLE64(secret+40)) - seed; + xxh_u64 const bitfliph = (XXH_readLE64(secret+48) ^ XXH_readLE64(secret+56)) + seed; + xxh_u64 const input_lo = XXH_readLE64(input); + xxh_u64 input_hi = XXH_readLE64(input + len - 8); + XXH128_hash_t m128 = XXH_mult64to128(input_lo ^ input_hi ^ bitflipl, XXH_PRIME64_1); + /* + * Put len in the middle of m128 to ensure that the length gets mixed to + * both the low and high bits in the 128x64 multiply below. + */ + m128.low64 += (xxh_u64)(len - 1) << 54; + input_hi ^= bitfliph; + /* + * Add the high 32 bits of input_hi to the high 32 bits of m128, then + * add the long product of the low 32 bits of input_hi and XXH_PRIME32_2 to + * the high 64 bits of m128. + * + * The best approach to this operation is different on 32-bit and 64-bit. + */ + if (sizeof(void *) < sizeof(xxh_u64)) { /* 32-bit */ + /* + * 32-bit optimized version, which is more readable. + * + * On 32-bit, it removes an ADC and delays a dependency between the two + * halves of m128.high64, but it generates an extra mask on 64-bit. + */ + m128.high64 += (input_hi & 0xFFFFFFFF00000000ULL) + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2); + } else { + /* + * 64-bit optimized (albeit more confusing) version. 
+ * + * Uses some properties of addition and multiplication to remove the mask: + * + * Let: + * a = input_hi.lo = (input_hi & 0x00000000FFFFFFFF) + * b = input_hi.hi = (input_hi & 0xFFFFFFFF00000000) + * c = XXH_PRIME32_2 + * + * a + (b * c) + * Inverse Property: x + y - x == y + * a + (b * (1 + c - 1)) + * Distributive Property: x * (y + z) == (x * y) + (x * z) + * a + (b * 1) + (b * (c - 1)) + * Identity Property: x * 1 == x + * a + b + (b * (c - 1)) + * + * Substitute a, b, and c: + * input_hi.hi + input_hi.lo + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + * + * Since input_hi.hi + input_hi.lo == input_hi, we get this: + * input_hi + ((xxh_u64)input_hi.lo * (XXH_PRIME32_2 - 1)) + */ + m128.high64 += input_hi + XXH_mult32to64((xxh_u32)input_hi, XXH_PRIME32_2 - 1); + } + /* m128 ^= XXH_swap64(m128 >> 64); */ + m128.low64 ^= XXH_swap64(m128.high64); + + { /* 128x64 multiply: h128 = m128 * XXH_PRIME64_2; */ + XXH128_hash_t h128 = XXH_mult64to128(m128.low64, XXH_PRIME64_2); + h128.high64 += m128.high64 * XXH_PRIME64_2; + + h128.low64 = XXH3_avalanche(h128.low64); + h128.high64 = XXH3_avalanche(h128.high64); + return h128; + } } +} + +/* + * Assumption: `secret` size is >= XXH3_SECRET_SIZE_MIN + */ +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_0to16_128b(const xxh_u8* input, size_t len, const xxh_u8* secret, XXH64_hash_t seed) +{ + XXH_ASSERT(len <= 16); + { if (len > 8) return XXH3_len_9to16_128b(input, len, secret, seed); + if (len >= 4) return XXH3_len_4to8_128b(input, len, secret, seed); + if (len) return XXH3_len_1to3_128b(input, len, secret, seed); + { XXH128_hash_t h128; + xxh_u64 const bitflipl = XXH_readLE64(secret+64) ^ XXH_readLE64(secret+72); + xxh_u64 const bitfliph = XXH_readLE64(secret+80) ^ XXH_readLE64(secret+88); + h128.low64 = XXH64_avalanche(seed ^ bitflipl); + h128.high64 = XXH64_avalanche( seed ^ bitfliph); + return h128; + } } +} + +/* + * A bit slower than XXH3_mix16B, but handles multiply by zero better. + */ +XXH_FORCE_INLINE XXH128_hash_t +XXH128_mix32B(XXH128_hash_t acc, const xxh_u8* input_1, const xxh_u8* input_2, + const xxh_u8* secret, XXH64_hash_t seed) +{ + acc.low64 += XXH3_mix16B (input_1, secret+0, seed); + acc.low64 ^= XXH_readLE64(input_2) + XXH_readLE64(input_2 + 8); + acc.high64 += XXH3_mix16B (input_2, secret+16, seed); + acc.high64 ^= XXH_readLE64(input_1) + XXH_readLE64(input_1 + 8); + return acc; +} + + +XXH_FORCE_INLINE XXH_PUREF XXH128_hash_t +XXH3_len_17to128_128b(const xxh_u8* XXH_RESTRICT input, size_t len, + const xxh_u8* XXH_RESTRICT secret, size_t secretSize, + XXH64_hash_t seed) +{ + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize; + XXH_ASSERT(16 < len && len <= 128); + + { XXH128_hash_t acc; + acc.low64 = len * XXH_PRIME64_1; + acc.high64 = 0; + +#if XXH_SIZE_OPT >= 1 + { + /* Smaller, but slightly slower. 
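+             * (Editor's note) This rolled loop visits the same
+             * (input, secret) offset pairs as the unrolled branch in the
+             * #else below, iterating i from (len-1)/32 down to 0.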
*/
+            unsigned int i = (unsigned int)(len - 1) / 32;
+            do {
+                acc = XXH128_mix32B(acc, input+16*i, input+len-16*(i+1), secret+32*i, seed);
+            } while (i-- != 0);
+        }
+#else
+        if (len > 32) {
+            if (len > 64) {
+                if (len > 96) {
+                    acc = XXH128_mix32B(acc, input+48, input+len-64, secret+96, seed);
+                }
+                acc = XXH128_mix32B(acc, input+32, input+len-48, secret+64, seed);
+            }
+            acc = XXH128_mix32B(acc, input+16, input+len-32, secret+32, seed);
+        }
+        acc = XXH128_mix32B(acc, input, input+len-16, secret, seed);
+#endif
+        {   XXH128_hash_t h128;
+            h128.low64  = acc.low64 + acc.high64;
+            h128.high64 = (acc.low64    * XXH_PRIME64_1)
+                        + (acc.high64   * XXH_PRIME64_4)
+                        + ((len - seed) * XXH_PRIME64_2);
+            h128.low64  = XXH3_avalanche(h128.low64);
+            h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+            return h128;
+        }
+    }
+}
+
+XXH_NO_INLINE XXH_PUREF XXH128_hash_t
+XXH3_len_129to240_128b(const xxh_u8* XXH_RESTRICT input, size_t len,
+                       const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+                       XXH64_hash_t seed)
+{
+    XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); (void)secretSize;
+    XXH_ASSERT(128 < len && len <= XXH3_MIDSIZE_MAX);
+
+    {   XXH128_hash_t acc;
+        unsigned i;
+        acc.low64 = len * XXH_PRIME64_1;
+        acc.high64 = 0;
+        /*
+         * We set `i` to offset + 32. We do this so that the unchanged
+         * `len` can be used as the upper bound. This reaches a sweet spot
+         * where both x86 and aarch64 get simple agen and good codegen
+         * for the loop.
+         */
+        for (i = 32; i < 160; i += 32) {
+            acc = XXH128_mix32B(acc,
+                                input  + i - 32,
+                                input  + i - 16,
+                                secret + i - 32,
+                                seed);
+        }
+        acc.low64 = XXH3_avalanche(acc.low64);
+        acc.high64 = XXH3_avalanche(acc.high64);
+        /*
+         * NB: `i <= len` will duplicate the last 32 bytes if
+         * len % 32 was zero. This is an unfortunate necessity to keep
+         * the hash result stable.
+         */
+        for (i=160; i <= len; i += 32) {
+            acc = XXH128_mix32B(acc,
+                                input + i - 32,
+                                input + i - 16,
+                                secret + XXH3_MIDSIZE_STARTOFFSET + i - 160,
+                                seed);
+        }
+        /* last bytes */
+        acc = XXH128_mix32B(acc,
+                            input + len - 16,
+                            input + len - 32,
+                            secret + XXH3_SECRET_SIZE_MIN - XXH3_MIDSIZE_LASTOFFSET - 16,
+                            (XXH64_hash_t)0 - seed);
+
+        {   XXH128_hash_t h128;
+            h128.low64  = acc.low64 + acc.high64;
+            h128.high64 = (acc.low64    * XXH_PRIME64_1)
+                        + (acc.high64   * XXH_PRIME64_4)
+                        + ((len - seed) * XXH_PRIME64_2);
+            h128.low64  = XXH3_avalanche(h128.low64);
+            h128.high64 = (XXH64_hash_t)0 - XXH3_avalanche(h128.high64);
+            return h128;
+        }
+    }
+}
+
+XXH_FORCE_INLINE XXH128_hash_t
+XXH3_hashLong_128b_internal(const void* XXH_RESTRICT input, size_t len,
+                            const xxh_u8* XXH_RESTRICT secret, size_t secretSize,
+                            XXH3_f_accumulate f_acc,
+                            XXH3_f_scrambleAcc f_scramble)
+{
+    XXH_ALIGN(XXH_ACC_ALIGN) xxh_u64 acc[XXH_ACC_NB] = XXH3_INIT_ACC;
+
+    XXH3_hashLong_internal_loop(acc, (const xxh_u8*)input, len, secret, secretSize, f_acc, f_scramble);
+
+    /* converge into final hash */
+    XXH_STATIC_ASSERT(sizeof(acc) == 64);
+    XXH_ASSERT(secretSize >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+    {   XXH128_hash_t h128;
+        h128.low64  = XXH3_mergeAccs(acc,
+                                     secret + XXH_SECRET_MERGEACCS_START,
+                                     (xxh_u64)len * XXH_PRIME64_1);
+        h128.high64 = XXH3_mergeAccs(acc,
+                                     secret + secretSize
+                                            - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+                                     ~((xxh_u64)len * XXH_PRIME64_2));
+        return h128;
+    }
+}
+
+/*
+ * It's important for performance that XXH3_hashLong() is not inlined.
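+ * (Editor's note: presumably for the same instruction-cache reasons given
+ * for the 64-bit variant earlier in this file.)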
+ */ +XXH_NO_INLINE XXH_PUREF XXH128_hash_t +XXH3_hashLong_128b_default(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; (void)secret; (void)secretLen; + return XXH3_hashLong_128b_internal(input, len, XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_accumulate, XXH3_scrambleAcc); +} + +/* + * It's important for performance to pass @p secretLen (when it's static) + * to the compiler, so that it can properly optimize the vectorized loop. + * + * When the secret size is unknown, or on GCC 12 where the mix of NO_INLINE and FORCE_INLINE + * breaks -Og, this is XXH_NO_INLINE. + */ +XXH3_WITH_SECRET_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSecret(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)seed64; + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, secretLen, + XXH3_accumulate, XXH3_scrambleAcc); +} + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed_internal(const void* XXH_RESTRICT input, size_t len, + XXH64_hash_t seed64, + XXH3_f_accumulate f_acc, + XXH3_f_scrambleAcc f_scramble, + XXH3_f_initCustomSecret f_initSec) +{ + if (seed64 == 0) + return XXH3_hashLong_128b_internal(input, len, + XXH3_kSecret, sizeof(XXH3_kSecret), + f_acc, f_scramble); + { XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + f_initSec(secret, seed64); + return XXH3_hashLong_128b_internal(input, len, (const xxh_u8*)secret, sizeof(secret), + f_acc, f_scramble); + } +} + +/* + * It's important for performance that XXH3_hashLong is not inlined. + */ +XXH_NO_INLINE XXH128_hash_t +XXH3_hashLong_128b_withSeed(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen) +{ + (void)secret; (void)secretLen; + return XXH3_hashLong_128b_withSeed_internal(input, len, seed64, + XXH3_accumulate, XXH3_scrambleAcc, XXH3_initCustomSecret); +} + +typedef XXH128_hash_t (*XXH3_hashLong128_f)(const void* XXH_RESTRICT, size_t, + XXH64_hash_t, const void* XXH_RESTRICT, size_t); + +XXH_FORCE_INLINE XXH128_hash_t +XXH3_128bits_internal(const void* input, size_t len, + XXH64_hash_t seed64, const void* XXH_RESTRICT secret, size_t secretLen, + XXH3_hashLong128_f f_hl128) +{ + XXH_ASSERT(secretLen >= XXH3_SECRET_SIZE_MIN); + /* + * If an action is to be taken if `secret` conditions are not respected, + * it should be done here. + * For now, it's a contract pre-condition. + * Adding a check and a branch here would cost performance at every hash. + */ + if (len <= 16) + return XXH3_len_0to16_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, seed64); + if (len <= 128) + return XXH3_len_17to128_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_len_129to240_128b((const xxh_u8*)input, len, (const xxh_u8*)secret, secretLen, seed64); + return f_hl128(input, len, seed64, secret, secretLen); +} + + +/* === Public XXH128 API === */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits(XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_128bits_internal(input, len, 0, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_default); +} + +/*! 
@ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSecret(XXH_NOESCAPE const void* input, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_128bits_internal(input, len, 0, + (const xxh_u8*)secret, secretSize, + XXH3_hashLong_128b_withSecret); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSeed(XXH_NOESCAPE const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_internal(input, len, seed, + XXH3_kSecret, sizeof(XXH3_kSecret), + XXH3_hashLong_128b_withSeed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH3_128bits_withSecretandSeed(XXH_NOESCAPE const void* input, size_t len, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + if (len <= XXH3_MIDSIZE_MAX) + return XXH3_128bits_internal(input, len, seed, XXH3_kSecret, sizeof(XXH3_kSecret), NULL); + return XXH3_hashLong_128b_withSecret(input, len, seed, secret, secretSize); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t +XXH128(XXH_NOESCAPE const void* input, size_t len, XXH64_hash_t seed) +{ + return XXH3_128bits_withSeed(input, len, seed); +} + + +/* === XXH3 128-bit streaming === */ +#ifndef XXH_NO_STREAM +/* + * All initialization and update functions are identical to 64-bit streaming variant. + * The only difference is the finalization routine. + */ + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset(XXH_NOESCAPE XXH3_state_t* statePtr) +{ + return XXH3_64bits_reset(statePtr); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecret(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize) +{ + return XXH3_64bits_reset_withSecret(statePtr, secret, secretSize); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH64_hash_t seed) +{ + return XXH3_64bits_reset_withSeed(statePtr, seed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_reset_withSecretandSeed(XXH_NOESCAPE XXH3_state_t* statePtr, XXH_NOESCAPE const void* secret, size_t secretSize, XXH64_hash_t seed) +{ + return XXH3_64bits_reset_withSecretandSeed(statePtr, secret, secretSize, seed); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_128bits_update(XXH_NOESCAPE XXH3_state_t* state, XXH_NOESCAPE const void* input, size_t len) +{ + return XXH3_64bits_update(state, input, len); +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API XXH128_hash_t XXH3_128bits_digest (XXH_NOESCAPE const XXH3_state_t* state) +{ + const unsigned char* const secret = (state->extSecret == NULL) ? 
state->customSecret : state->extSecret;
+    if (state->totalLen > XXH3_MIDSIZE_MAX) {
+        XXH_ALIGN(XXH_ACC_ALIGN) XXH64_hash_t acc[XXH_ACC_NB];
+        XXH3_digest_long(acc, state, secret);
+        XXH_ASSERT(state->secretLimit + XXH_STRIPE_LEN >= sizeof(acc) + XXH_SECRET_MERGEACCS_START);
+        {   XXH128_hash_t h128;
+            h128.low64  = XXH3_mergeAccs(acc,
+                                         secret + XXH_SECRET_MERGEACCS_START,
+                                         (xxh_u64)state->totalLen * XXH_PRIME64_1);
+            h128.high64 = XXH3_mergeAccs(acc,
+                                         secret + state->secretLimit + XXH_STRIPE_LEN
+                                                - sizeof(acc) - XXH_SECRET_MERGEACCS_START,
+                                         ~((xxh_u64)state->totalLen * XXH_PRIME64_2));
+            return h128;
+        }
+    }
+    /* len <= XXH3_MIDSIZE_MAX : short code */
+    if (state->useSeed)
+        return XXH3_128bits_withSeed(state->buffer, (size_t)state->totalLen, state->seed);
+    return XXH3_128bits_withSecret(state->buffer, (size_t)(state->totalLen),
+                                   secret, state->secretLimit + XXH_STRIPE_LEN);
+}
+#endif /* !XXH_NO_STREAM */
+/* 128-bit utility functions */
+
+#include <string.h>   /* memcmp, memcpy */
+
+/* return : 1 if equal, 0 if different */
+/*! @ingroup XXH3_family */
+XXH_PUBLIC_API int XXH128_isEqual(XXH128_hash_t h1, XXH128_hash_t h2)
+{
+    /* note : XXH128_hash_t is compact, it has no padding byte */
+    return !(memcmp(&h1, &h2, sizeof(h1)));
+}
+
+/* This prototype is compatible with stdlib's qsort().
+ * @return : >0 if *h128_1 > *h128_2
+ *           <0 if *h128_1 < *h128_2
+ *           =0 if *h128_1 == *h128_2 */
+/*! @ingroup XXH3_family */
+XXH_PUBLIC_API int XXH128_cmp(XXH_NOESCAPE const void* h128_1, XXH_NOESCAPE const void* h128_2)
+{
+    XXH128_hash_t const h1 = *(const XXH128_hash_t*)h128_1;
+    XXH128_hash_t const h2 = *(const XXH128_hash_t*)h128_2;
+    int const hcmp = (h1.high64 > h2.high64) - (h2.high64 > h1.high64);
+    /* note : bets that, in most cases, hash values are different */
+    if (hcmp) return hcmp;
+    return (h1.low64 > h2.low64) - (h2.low64 > h1.low64);
+}
+
+
+/*====== Canonical representation ======*/
+/*! @ingroup XXH3_family */
+XXH_PUBLIC_API void
+XXH128_canonicalFromHash(XXH_NOESCAPE XXH128_canonical_t* dst, XXH128_hash_t hash)
+{
+    XXH_STATIC_ASSERT(sizeof(XXH128_canonical_t) == sizeof(XXH128_hash_t));
+    if (XXH_CPU_LITTLE_ENDIAN) {
+        hash.high64 = XXH_swap64(hash.high64);
+        hash.low64  = XXH_swap64(hash.low64);
+    }
+    XXH_memcpy(dst, &hash.high64, sizeof(hash.high64));
+    XXH_memcpy((char*)dst + sizeof(hash.high64), &hash.low64, sizeof(hash.low64));
+}
+
+/*! @ingroup XXH3_family */
+XXH_PUBLIC_API XXH128_hash_t
+XXH128_hashFromCanonical(XXH_NOESCAPE const XXH128_canonical_t* src)
+{
+    XXH128_hash_t h;
+    h.high64 = XXH_readBE64(src);
+    h.low64  = XXH_readBE64(src->digest + 8);
+    return h;
+}
+
+
+
+/* ==========================================
+ * Secret generators
+ * ==========================================
+ */
+#define XXH_MIN(x, y) (((x) > (y)) ? (y) : (x))
+
+XXH_FORCE_INLINE void XXH3_combine16(void* dst, XXH128_hash_t h128)
+{
+    XXH_writeLE64( dst,          XXH_readLE64(dst)          ^ h128.low64 );
+    XXH_writeLE64( (char*)dst+8, XXH_readLE64((char*)dst+8) ^ h128.high64 );
+}
+
+/*!
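+ * Derives a custom secret from caller-provided entropy.
+ *
+ * Usage sketch (editor's note; buffer names are hypothetical, and `data`,
+ * `size` stand for the input to hash):
+ * @code
+ *     unsigned char secret[XXH3_SECRET_SIZE_MIN];
+ *     const char entropy[] = "application-specific entropy, any length";
+ *     if (XXH3_generateSecret(secret, sizeof(secret), entropy, sizeof(entropy)) == XXH_OK) {
+ *         XXH64_hash_t const h = XXH3_64bits_withSecret(data, size, secret, sizeof(secret));
+ *     }
+ * @endcode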
@ingroup XXH3_family */ +XXH_PUBLIC_API XXH_errorcode +XXH3_generateSecret(XXH_NOESCAPE void* secretBuffer, size_t secretSize, XXH_NOESCAPE const void* customSeed, size_t customSeedSize) +{ +#if (XXH_DEBUGLEVEL >= 1) + XXH_ASSERT(secretBuffer != NULL); + XXH_ASSERT(secretSize >= XXH3_SECRET_SIZE_MIN); +#else + /* production mode, assert() are disabled */ + if (secretBuffer == NULL) return XXH_ERROR; + if (secretSize < XXH3_SECRET_SIZE_MIN) return XXH_ERROR; +#endif + + if (customSeedSize == 0) { + customSeed = XXH3_kSecret; + customSeedSize = XXH_SECRET_DEFAULT_SIZE; + } +#if (XXH_DEBUGLEVEL >= 1) + XXH_ASSERT(customSeed != NULL); +#else + if (customSeed == NULL) return XXH_ERROR; +#endif + + /* Fill secretBuffer with a copy of customSeed - repeat as needed */ + { size_t pos = 0; + while (pos < secretSize) { + size_t const toCopy = XXH_MIN((secretSize - pos), customSeedSize); + memcpy((char*)secretBuffer + pos, customSeed, toCopy); + pos += toCopy; + } } + + { size_t const nbSeg16 = secretSize / 16; + size_t n; + XXH128_canonical_t scrambler; + XXH128_canonicalFromHash(&scrambler, XXH128(customSeed, customSeedSize, 0)); + for (n=0; n<nbSeg16; n++) { + XXH128_hash_t const h128 = XXH128(&scrambler, sizeof(scrambler), n); + XXH3_combine16((char*)secretBuffer + n*16, h128); + } + /* last segment */ + XXH3_combine16((char*)secretBuffer + secretSize - 16, XXH128_hashFromCanonical(&scrambler)); + } + return XXH_OK; +} + +/*! @ingroup XXH3_family */ +XXH_PUBLIC_API void +XXH3_generateSecret_fromSeed(XXH_NOESCAPE void* secretBuffer, XXH64_hash_t seed) +{ + XXH_ALIGN(XXH_SEC_ALIGN) xxh_u8 secret[XXH_SECRET_DEFAULT_SIZE]; + XXH3_initCustomSecret(secret, seed); + XXH_ASSERT(secretBuffer != NULL); + memcpy(secretBuffer, secret, XXH_SECRET_DEFAULT_SIZE); +} + + + +/* Pop our optimization override from above */ +#if XXH_VECTOR == XXH_AVX2 /* AVX2 */ \ + && defined(__GNUC__) && !defined(__clang__) /* GCC, not Clang */ \ + && defined(__OPTIMIZE__) && XXH_SIZE_OPT <= 0 /* respect -O0 and -Os */ +# pragma GCC pop_options +#endif + +#endif /* XXH_NO_LONG_LONG */ + +#endif /* XXH_NO_XXH3 */ + +/*! 
+ * @}
+ */
+#endif  /* XXH_IMPLEMENTATION */
+
+
+#if defined (__cplusplus)
+} /* extern "C" */
+#endif
diff --git a/llama.cpp/examples/gguf-hash/gguf-hash.cpp b/llama.cpp/examples/gguf-hash/gguf-hash.cpp
new file mode 100644
index 0000000..9523ec1
--- /dev/null
+++ b/llama.cpp/examples/gguf-hash/gguf-hash.cpp
@@ -0,0 +1,694 @@
+#include "ggml.h"
+#include "gguf.h"
+
+#include <cstdlib>   /* abort() */
+#include <cstddef>
+#include <cstdio>
+#include <string>
+#include <stdexcept>
+#include <algorithm>
+#include <cstring>
+
+#include <sstream>
+#include <fstream>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "xxhash/xxhash.h"
+#include "sha1/sha1.h"
+#include "sha256/sha256.h"
+
+#ifdef __cplusplus
+}
+#endif
+
+
+// uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp')
+#define UUID_NAMESPACE_LLAMA_CPP "ef001206-dadc-5f6d-a15f-3359e577d4e5"
+#define UUID_NAMESPACE_LLAMA_CPP_HEX 0xef, 0x00, 0x12, 0x06, 0xda, 0xdc, 0x5f, 0x6d, 0xa1, 0x5f, 0x33, 0x59, 0xe5, 0x77, 0xd4, 0xe5
+
+
+#define HASH_TYPE_SHA256_STR "sha256"
+#define HASH_TYPE_SHA1_STR   "sha1"
+#define HASH_TYPE_XXH64_STR  "xxh64"
+#define HASH_TYPE_UUID_STR   "uuid"
+
+
+typedef enum {
+    HASH_EXIT_SUCCESS = 0,                // All hashes have been generated or validated
+    HASH_EXIT_FAILURE = 1,                // Generic Failure
+    HASH_EXIT_MISMATCH = 2,               // Hash mismatched during validation
+    HASH_EXIT_MANIFEST_MISSING_ENTRY = 3, // Validation attempted, but entry missing from manifest
+    HASH_EXIT_MANIFEST_UNKNOWN_HASH = 4,  // Manifest is present, but we do not know any hash format within it
+    HASH_EXIT_MANIFEST_FILE_ERROR = 5     // Manifest is either missing or not in a known format
+} hash_exit_code_t;
+
+
+typedef enum {
+    HASH_MANIFEST_NOT_FOUND,
+    HASH_MANIFEST_MISMATCH,
+    HASH_MANIFEST_OK,
+} hash_manifest_result_t;
+
+
+struct hash_params {
+    std::string input;
+    bool xxh64 = false;
+    bool sha1 = false;
+    bool sha256 = false;
+    bool uuid = false;
+
+    bool no_layer = false;
+
+    bool manifest_is_usable = false;
+    std::string manifest_file;
+};
+
+struct manifest_check_params {
+    bool xxh64 = false;
+    bool sha1 = false;
+    bool sha256 = false;
+    bool uuid = false;
+};
+
+static char const * hash_manifest_result_to_str(hash_manifest_result_t value) {
+    switch (value) {
+        case HASH_MANIFEST_NOT_FOUND: return "Not Found";
+        case HASH_MANIFEST_MISMATCH: return "Mismatch";
+        case HASH_MANIFEST_OK: return "Ok";
+    }
+    return "?";
+}
+
+static char const * hash_exit_code_to_str(hash_exit_code_t value) {
+    switch (value) {
+        case HASH_EXIT_SUCCESS: return "Success";
+        case HASH_EXIT_FAILURE: return "Failure";
+        case HASH_EXIT_MISMATCH: return "Mismatch";
+        case HASH_EXIT_MANIFEST_MISSING_ENTRY: return "Manifest Missing Entry";
+        case HASH_EXIT_MANIFEST_UNKNOWN_HASH: return "Manifest Unknown Hash";
+        case HASH_EXIT_MANIFEST_FILE_ERROR: return "Manifest File Error";
+    }
+    return "?";
+}
+
+static void hash_print_usage(const char * executable) {
+    const hash_params default_params;
+    printf("\n");
+    printf("usage: %s [options] GGUF_IN\n", executable);
+    printf("\n");
+    printf("Hash a GGUF file");
+    printf("\n");
+    printf("options:\n");
+    printf("  -h, --help              show this help message and exit\n");
+    printf("      --xxh64             use xxh64 hash\n");
+    printf("      --sha1              use sha1 hash\n");
+    printf("      --sha256            use sha256 hash\n");
+    printf("      --all               use all hashes\n");
+    printf("      --no-layer          exclude per-layer hashes\n");
+    printf("      --uuid              generate UUIDv5 ID\n");
+    printf("  -c, --check <manifest>  verify against a manifest\n");
+    printf("\n");
+}
+
+static void hash_params_parse_ex(int argc,
const char ** argv, hash_params & params) { + std::string arg; + bool invalid_param = false; + const std::string arg_prefix = "--"; + + int arg_idx = 1; + for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { + arg = argv[arg_idx]; + if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { + std::replace(arg.begin(), arg.end(), '_', '-'); + } + + bool arg_found = false; + if (arg == "-h" || arg == "--help") { + hash_print_usage(argv[0]); + exit(0); + } + + if (arg == "--xxh64") { + arg_found = true; + params.xxh64 = true; + } + + if (arg == "--sha1") { + arg_found = true; + params.sha1 = true; + } + + if (arg == "--uuid") { + arg_found = true; + params.uuid = true; + } + + if (arg == "--sha256") { + arg_found = true; + params.sha256 = true; + } + + if (arg == "--all") { + arg_found = true; + params.sha256 = true; + params.sha1 = true; + params.xxh64 = true; + } + + if (arg == "--no-layer") { + arg_found = true; + params.no_layer = true; + } + + if (arg == "-c" || arg == "--check") { + if (++arg_idx >= argc) { + invalid_param = true; + break; + } + arg_found = true; + params.manifest_file = argv[arg_idx]; + } + + if (!arg_found) { + throw std::invalid_argument("error: unknown argument: " + arg); + } + } + + if (invalid_param) { + throw std::invalid_argument("error: invalid parameter for argument:" + arg); + } + + if (argc - arg_idx < 1) { + throw std::invalid_argument("error: bad arguments"); + } + + params.input = argv[arg_idx++]; +} + +static bool hash_params_parse(int argc, const char ** argv, hash_params & params) { + bool result = true; + try { + hash_params_parse_ex(argc, argv, params); + } + catch (const std::invalid_argument & ex) { + fprintf(stderr, "%s\n", ex.what()); + hash_print_usage(argv[0]); + exit(EXIT_FAILURE); + } + return result; +} + +static bool manifest_type(const std::string & manifest_file, manifest_check_params & manifest_check) { + if (manifest_file.empty()) { + return false; + } + + std::ifstream file(manifest_file); + if (!file.is_open()) { + return false; + } + + std::string manifest_entry_line; + while (getline(file, manifest_entry_line)) { + // hash_type_str hash_str tensor_name + // e.g. 'xxh64 f66e9cd66a4396a0 test.gguf:tensor_0' + std::istringstream line_stream(manifest_entry_line); + std::string file_hash_type; + if (line_stream >> file_hash_type) { + if (file_hash_type == HASH_TYPE_SHA256_STR) { + manifest_check.sha256 = true; + } else if (file_hash_type == HASH_TYPE_SHA1_STR) { + manifest_check.sha1 = true; + } else if (file_hash_type == HASH_TYPE_XXH64_STR) { + manifest_check.xxh64 = true; + } else if (file_hash_type == HASH_TYPE_UUID_STR) { + manifest_check.uuid = true; + } + } + } + + return true; +} + +static hash_manifest_result_t manifest_verify(const std::string& manifest_file, const std::string& hash_type_str, const std::string& hash_str, const std::string& tensor_name) { + if (manifest_file.empty()) { + return HASH_MANIFEST_NOT_FOUND; + } + + std::ifstream file(manifest_file); + if (!file.is_open()) { + return HASH_MANIFEST_NOT_FOUND; + } + + std::string manifest_entry_line; + while (getline(file, manifest_entry_line)) { + std::istringstream line_stream(manifest_entry_line); + std::string file_hash_type; + std::string file_hash; + std::string file_tensor_name; + if (line_stream >> file_hash_type >> file_hash >> file_tensor_name) { + // Line parsed. Check hash validity + + if (file_hash_type != hash_type_str) { + continue; + } + + if (file_tensor_name != tensor_name) { + continue; + } + + return (file_hash == hash_str) ? 
HASH_MANIFEST_OK : HASH_MANIFEST_MISMATCH; + } + } + + return HASH_MANIFEST_NOT_FOUND; +} + +static void generate_uuidv5(const unsigned char sha1_digest[20], unsigned char uuid[16]) { + // Ref: https://www.rfc-editor.org/rfc/rfc9562.html#section-5.5 + // Assumes that digest was processed correctly with the expected namespace + for (int i = 0; i < 16; i++) { + uuid[i] = sha1_digest[i]; + } + + // Set bits corresponding to UUID ver 5 + uuid[ 6] &= ~(0xF << 4); + uuid[ 6] |= (5 << 4); + + // Set bits corresponding to UUID variant 0b10XX + uuid[ 8] &= ~(0xc << 4); + uuid[ 8] |= (0x8 << 4); +} + +static hash_exit_code_t gguf_hash(const hash_params & hash_params) { + const std::string & fname = hash_params.input; + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + // xxh64 init + XXH64_state_t* xxh64_model_hash_state = NULL; + if (hash_params.xxh64) { + xxh64_model_hash_state = XXH64_createState(); + if (xxh64_model_hash_state==NULL) { + abort(); + } + + XXH64_hash_t const seed = 0; + if (XXH64_reset(xxh64_model_hash_state, seed) == XXH_ERROR) { + abort(); + } + } + + // sha1 init + SHA1_CTX sha1_model_hash_ctx; + if (hash_params.sha1) { + SHA1Init(&sha1_model_hash_ctx); + } + + // sha256 init + sha256_t sha256_model_hash_ctx; + if (hash_params.sha256) { + sha256_init(&sha256_model_hash_ctx); + } + + // sha1 for uuid init + SHA1_CTX sha1_for_uuid_ctx; + if (hash_params.uuid) { + unsigned char const uuidv5_namespace[] = {UUID_NAMESPACE_LLAMA_CPP_HEX}; + SHA1Init(&sha1_for_uuid_ctx); + SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)uuidv5_namespace, sizeof(uuidv5_namespace)); + } + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + const int n_tensors = gguf_get_n_tensors(ctx); + bool tensor_layer_in_manifest = false; + bool model_in_manifest = false; + bool tensor_layer_has_mismatch = false; + bool model_has_mismatch = false; + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx, i); + struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + auto n_bytes = ggml_nbytes(cur); + auto *raw_data = cur->data; + const std::string tensor_layer_name = fname + ":" + name; + + if (hash_params.xxh64) { + + if (!hash_params.no_layer) { + // Per Layer Hash + XXH64_hash_t hash = XXH64(raw_data, n_bytes, 0); + + char hex_result[17]; + for (int offset = 0; offset < 8; offset++) { + unsigned int shift_bits_by = (8 * (8 - offset - 1)); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + if (XXH64_update(xxh64_model_hash_state, raw_data, n_bytes) == XXH_ERROR) abort(); + } + + if (hash_params.sha1) { + + if (!hash_params.no_layer) { + // Per Layer Hash + char result[21]; // sha1 outputs 20 
bytes + SHA1( result, (const char *)raw_data, n_bytes); + + char hex_result[41] = {0}; + for (int offset = 0; offset < 20; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + SHA1Update( &sha1_model_hash_ctx, (unsigned char const *)raw_data, n_bytes); + } + + if (hash_params.sha256) { + + if (!hash_params.no_layer) { + // Per Layer Hash + unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes + sha256_hash((unsigned char*) result, (const unsigned char *)raw_data, n_bytes); + + char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; + for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + tensor_layer_in_manifest = true; + tensor_layer_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + tensor_layer_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, tensor_layer_name.c_str()); + } + } + + // Overall Model Hash + sha256_update( &sha256_model_hash_ctx, (unsigned char const *)raw_data, n_bytes); + } + + if (hash_params.uuid) { + SHA1Update( &sha1_for_uuid_ctx, (unsigned char const *)raw_data, n_bytes); + } + } + + if (hash_params.xxh64) { + XXH64_hash_t const hash = XXH64_digest(xxh64_model_hash_state); + + char hex_result[17]; + for (int offset = 0; offset < 8; offset++) { + unsigned int shift_bits_by = (8 * (8 - offset - 1)); + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_XXH64_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_XXH64_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.sha1) { + unsigned char result[21]; + SHA1Final(result, &sha1_model_hash_ctx); + + char hex_result[41]; + for (int offset = 0; offset < 20; offset++) { + snprintf( ( 
hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA1_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA1_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.sha256) { + unsigned char result[SHA256_DIGEST_SIZE]; // sha256 outputs 32 bytes + sha256_final( &sha256_model_hash_ctx, result); + + char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0}; + for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) { + snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff); + } + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, hex_result, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_SHA256_STR, hex_result, fname.c_str()); + } + } + + if (hash_params.uuid) { + unsigned char result[21]; + SHA1Final(result, &sha1_for_uuid_ctx); + + unsigned char uuid[16]; + generate_uuidv5(result, uuid); + + char string_buffer[37] = {0}; + snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], uuid[6], uuid[7], + uuid[8], uuid[9], uuid[10], uuid[11], + uuid[12], uuid[13], uuid[14], uuid[15]); + + if (hash_params.manifest_is_usable) { + hash_manifest_result_t verify_result = manifest_verify(hash_params.manifest_file, HASH_TYPE_SHA256_STR, string_buffer, fname); + + switch (verify_result) { + case HASH_MANIFEST_NOT_FOUND: + break; + case HASH_MANIFEST_MISMATCH: + model_in_manifest = true; + model_has_mismatch = true; + break; + case HASH_MANIFEST_OK: + model_in_manifest = true; + break; + } + + printf("%-8s %-s %s - %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str(), hash_manifest_result_to_str(verify_result)); + } else { + printf("%-8s %-s %s\n", HASH_TYPE_UUID_STR, string_buffer, fname.c_str()); + } + } + + + ggml_free(ctx_data); + gguf_free(ctx); + + + if (hash_params.manifest_is_usable) { + // In hash verification mode + + if (!model_in_manifest) { + // model missing in manifest? + + // Check tensor layer... + if (!tensor_layer_in_manifest) { + // Still missing? Maybe we are reading the wrong manifest. + return HASH_EXIT_MANIFEST_MISSING_ENTRY; + } + + if (tensor_layer_has_mismatch) { + // Per tensor check found error + return HASH_EXIT_FAILURE; + } + + // All per tensor layer checks passed? Sounds good enough. 
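+            // (editor's note: illustrative only, not part of the original patch)
+            // a manifest is just the captured output of a generation run, one
+            // "<type> <hex> <name>" line per hash, in the same layout as the
+            // printf calls above, e.g.:
+            //
+            //     xxh64    f66e9cd66a4396a0    model.gguf:tensor_0
+            //     sha256   <64 hex chars>      model.gguf
+            //
+            // manifest_verify() looks each entry up in that file, which is why a
+            // hash that was never generated surfaces as HASH_MANIFEST_NOT_FOUND
+            // rather than as a mismatch.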
+            return HASH_EXIT_SUCCESS;
+        }
+
+        // Overall model check passed, but let's check per layer just in case
+        // If missing, we don't care too much as the overall model hash checked out
+        if (tensor_layer_in_manifest && tensor_layer_has_mismatch) {
+            return HASH_EXIT_FAILURE;
+        }
+
+        if (model_has_mismatch) {
+            // a hash mismatch was found somewhere in the model
+            return HASH_EXIT_FAILURE;
+        }
+
+        // All checks appear to be fine
+        return HASH_EXIT_SUCCESS;
+    }
+
+    // In hash generation mode
+    return HASH_EXIT_SUCCESS;
+}
+
+int main(int argc, const char ** argv) {
+    hash_params params;
+    manifest_check_params manifest_check;
+    hash_params_parse(argc, argv, params);
+
+    if (!params.manifest_file.empty()) {
+        if (!manifest_type(params.manifest_file, manifest_check)) {
+            printf("ERROR cannot open manifest %s\n", params.manifest_file.c_str());
+            return HASH_EXIT_MANIFEST_FILE_ERROR;
+        }
+
+        if (!manifest_check.sha256 && !manifest_check.sha1 && !manifest_check.xxh64 && !manifest_check.uuid) {
+            printf("ERROR manifest does not have any known hash format in %s\n", params.manifest_file.c_str());
+            return HASH_EXIT_MANIFEST_UNKNOWN_HASH;
+        }
+
+        printf("manifest %s", params.manifest_file.c_str());
+
+        if (manifest_check.sha256) {
+            printf(" sha256");
+        }
+
+        if (manifest_check.sha1) {
+            printf(" sha1");
+        }
+
+        if (manifest_check.xxh64) {
+            printf(" xxh64");
+        }
+
+        if (manifest_check.uuid) {
+            printf(" uuid");
+        }
+
+        printf("\n");
+
+        // Autoselect the highest security hash if manifest is provided but
+        // the user has not specifically defined the hash they care about
+        if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
+            // User has not selected a specific value, pick most secure hash
+            if (manifest_check.sha256) {
+                params.sha256 = true;
+            } else if (manifest_check.sha1) {
+                params.sha1 = true;
+            } else if (manifest_check.xxh64) {
+                params.xxh64 = true;
+            } else if (manifest_check.uuid) {
+                params.uuid = true;
+            }
+        }
+
+        params.manifest_is_usable = true;
+    }
+
+    // By default, if no switch argument is provided, assume xxh64
+    if (!params.xxh64 && !params.sha1 && !params.uuid && !params.sha256) {
+        params.xxh64 = true;
+    }
+
+    hash_exit_code_t exit_code = gguf_hash(params);
+
+    if (params.manifest_is_usable) {
+        printf("\nVerification results for %s - %s\n", params.manifest_file.c_str(), hash_exit_code_to_str(exit_code));
+    }
+
+    return exit_code;
+}
diff --git a/llama.cpp/examples/gguf/CMakeLists.txt b/llama.cpp/examples/gguf/CMakeLists.txt
new file mode 100644
index 0000000..fb04eb8
--- /dev/null
+++ b/llama.cpp/examples/gguf/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-gguf)
+add_executable(${TARGET} gguf.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE ggml ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/gguf/gguf.cpp b/llama.cpp/examples/gguf/gguf.cpp
new file mode 100644
index 0000000..499cfac
--- /dev/null
+++ b/llama.cpp/examples/gguf/gguf.cpp
@@ -0,0 +1,270 @@
+#include "ggml.h"
+#include "gguf.h"
+
+#include <cstdio>
+#include <string>
+#include <sstream>
+#include <vector>
+
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
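+
+// (editor's note: hypothetical usage sketch, not part of the original patch;
+// the modes come from the usage text printed by main() below)
+//
+//     ./llama-gguf data.gguf w     # write a small synthetic GGUF file
+//     ./llama-gguf data.gguf r     # read it back and verify the tensor data
+//     ./llama-gguf data.gguf r n   # read it back, skipping the data check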
+
+template <typename T>
+static std::string to_string(const T & val) {
+    std::stringstream ss;
+    ss << val;
+    return ss.str();
+}
+
+static bool gguf_ex_write(const std::string & fname) {
+    struct gguf_context * ctx = gguf_init_empty();
+
+    gguf_set_val_u8 (ctx, "some.parameter.uint8", 0x12);
+    gguf_set_val_i8 (ctx, "some.parameter.int8", -0x13);
+    gguf_set_val_u16 (ctx, "some.parameter.uint16", 0x1234);
+    gguf_set_val_i16 (ctx, "some.parameter.int16", -0x1235);
+    gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678);
+    gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679);
+    gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f);
+    gguf_set_val_u64 (ctx, "some.parameter.uint64", 0x123456789abcdef0ull);
+    gguf_set_val_i64 (ctx, "some.parameter.int64", -0x123456789abcdef1ll);
+    gguf_set_val_f64 (ctx, "some.parameter.float64", 0.1234567890123456789);
+    gguf_set_val_bool(ctx, "some.parameter.bool", true);
+    gguf_set_val_str (ctx, "some.parameter.string", "hello world");
+
+    gguf_set_arr_data(ctx, "some.parameter.arr.i16", GGUF_TYPE_INT16, std::vector<int16_t>{ 1, 2, 3, 4, }.data(), 4);
+    gguf_set_arr_data(ctx, "some.parameter.arr.f32", GGUF_TYPE_FLOAT32, std::vector<float>{ 3.145f, 2.718f, 1.414f, }.data(), 3);
+    gguf_set_arr_str (ctx, "some.parameter.arr.str", std::vector<const char *>{ "hello", "world", "!" }.data(), 3);
+
+    struct ggml_init_params params = {
+        /*.mem_size   =*/ 128ull*1024ull*1024ull,
+        /*.mem_buffer =*/ NULL,
+        /*.no_alloc   =*/ false,
+    };
+
+    struct ggml_context * ctx_data = ggml_init(params);
+
+    const int n_tensors = 10;
+
+    // tensor infos
+    for (int i = 0; i < n_tensors; ++i) {
+        const std::string name = "tensor_" + to_string(i);
+
+        int64_t ne[GGML_MAX_DIMS] = { 1 };
+        int32_t n_dims = rand() % GGML_MAX_DIMS + 1;
+
+        for (int j = 0; j < n_dims; ++j) {
+            ne[j] = rand() % 10 + 1;
+        }
+
+        struct ggml_tensor * cur = ggml_new_tensor(ctx_data, GGML_TYPE_F32, n_dims, ne);
+        ggml_set_name(cur, name.c_str());
+
+        {
+            float * data = (float *) cur->data;
+            for (int j = 0; j < ggml_nelements(cur); ++j) {
+                data[j] = 100 + i;
+            }
+        }
+
+        gguf_add_tensor(ctx, cur);
+    }
+
+    gguf_write_to_file(ctx, fname.c_str(), false);
+
+    printf("%s: wrote file '%s'\n", __func__, fname.c_str());
+
+    ggml_free(ctx_data);
+    gguf_free(ctx);
+
+    return true;
+}
+
+// just read tensor info
+static bool gguf_ex_read_0(const std::string & fname) {
+    struct gguf_init_params params = {
+        /*.no_alloc = */ false,
+        /*.ctx      = */ NULL,
+    };
+
+    struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params);
+
+    if (!ctx) {
+        fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname.c_str());
+        return false;
+    }
+
+    printf("%s: version: %d\n", __func__, gguf_get_version(ctx));
+    printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx));
+    printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx));
+
+    // kv
+    {
+        const int n_kv = gguf_get_n_kv(ctx);
+
+        printf("%s: n_kv: %d\n", __func__, n_kv);
+
+        for (int i = 0; i < n_kv; ++i) {
+            const char * key = gguf_get_key(ctx, i);
+
+            printf("%s: kv[%d]: key = %s\n", __func__, i, key);
+        }
+    }
+
+    // find kv string
+    {
+        const char * findkey = "some.parameter.string";
+
+        const int keyidx = gguf_find_key(ctx, findkey);
+        if (keyidx == -1) {
+            printf("%s: find key: %s not found.\n", __func__, findkey);
+        } else {
+            const char * key_value = gguf_get_val_str(ctx, keyidx);
+            printf("%s: find key: %s found, kv[%d] value = %s\n", __func__, findkey, keyidx, key_value);
+        }
+    }
+
+    // tensor info
+    {
+        const int
n_tensors = gguf_get_n_tensors(ctx); + + printf("%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu\n", __func__, i, name, size, offset); + } + } + + gguf_free(ctx); + + return true; +} + +// read and create ggml_context containing the tensors and their data +static bool gguf_ex_read_1(const std::string & fname, bool check_data) { + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + + printf("%s: version: %d\n", __func__, gguf_get_version(ctx)); + printf("%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + printf("%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = gguf_get_n_kv(ctx); + + printf("%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ctx, i); + + printf("%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + + // tensor info + { + const int n_tensors = gguf_get_n_tensors(ctx); + + printf("%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t size = gguf_get_tensor_size (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + const auto type = gguf_get_tensor_type (ctx, i); + + const char * type_name = ggml_type_name(type); + const size_t type_size = ggml_type_size(type); + const size_t n_elements = size / type_size; + + printf("%s: tensor[%d]: name = %s, size = %zu, offset = %zu, type = %s, n_elts = %zu\n", __func__, i, name, size, offset, type_name, n_elements); + } + } + + // data + { + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + printf("%s: reading tensor %d data\n", __func__, i); + + const char * name = gguf_get_tensor_name(ctx, i); + + struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + + printf("%s: tensor[%d]: n_dims = %d, ne = (%d, %d, %d, %d), name = %s, data = %p\n", + __func__, i, ggml_n_dims(cur), int(cur->ne[0]), int(cur->ne[1]), int(cur->ne[2]), int(cur->ne[3]), cur->name, cur->data); + + // print first 10 elements + const float * data = (const float *) cur->data; + + printf("%s data[:10] : ", name); + for (int j = 0; j < MIN(10, ggml_nelements(cur)); ++j) { + printf("%f ", data[j]); + } + printf("\n\n"); + + // check data + if (check_data) { + const float * data = (const float *) cur->data; + for (int j = 0; j < ggml_nelements(cur); ++j) { + if (data[j] != 100 + i) { + fprintf(stderr, "%s: tensor[%d], data[%d]: found %f, expected %f\n", __func__, i, j, data[j], float(100 + i)); + gguf_free(ctx); + return false; + } + } + } + } + } + + printf("%s: ctx_data size: %zu\n", __func__, ggml_get_mem_size(ctx_data)); + + ggml_free(ctx_data); + gguf_free(ctx); + + return true; +} + +int main(int argc, char ** argv) { + if (argc < 3) { + printf("usage: %s data.gguf r|w [n]\n", argv[0]); + printf("r: read data.gguf file\n"); + printf("w: write data.gguf file\n"); + printf("n: no check of tensor data\n"); + return -1; + } + bool check_data = true; + if (argc == 4) { + check_data = false; + } + + srand(123456); + + const std::string fname(argv[1]); + const std::string mode (argv[2]); + + 
GGML_ASSERT((mode == "r" || mode == "w") && "mode must be r or w"); + + if (mode == "w") { + GGML_ASSERT(gguf_ex_write(fname) && "failed to write gguf file"); + } else if (mode == "r") { + GGML_ASSERT(gguf_ex_read_0(fname) && "failed to read gguf file"); + GGML_ASSERT(gguf_ex_read_1(fname, check_data) && "failed to read gguf file"); + } + + return 0; +} diff --git a/llama.cpp/examples/idle/CMakeLists.txt b/llama.cpp/examples/idle/CMakeLists.txt new file mode 100644 index 0000000..d5018fe --- /dev/null +++ b/llama.cpp/examples/idle/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-idle) +add_executable(${TARGET} idle.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/llama.cpp/examples/idle/README.md b/llama.cpp/examples/idle/README.md new file mode 100644 index 0000000..d38e607 --- /dev/null +++ b/llama.cpp/examples/idle/README.md @@ -0,0 +1,3 @@ +# llama.cpp/example/idle + +https://github.com/ggml-org/llama.cpp/pull/17766 diff --git a/llama.cpp/examples/idle/idle.cpp b/llama.cpp/examples/idle/idle.cpp new file mode 100644 index 0000000..0004271 --- /dev/null +++ b/llama.cpp/examples/idle/idle.cpp @@ -0,0 +1,110 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include <cmath> +#include <cstdio> +#include <cstring> +#include <string> +#include <thread> +#include <vector> + +static void print_usage(int /*argc*/, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-ngl n_gpu_layers]\n", argv[0]); + printf("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) { + return 1; + } + + common_init(); + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = common_model_params_to_llama(params); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + + if (model == NULL) { + LOG_ERR("%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // we need just a dummy token to evaluate + std::vector<llama_token> prompt_tokens(1, llama_vocab_bos(vocab)); + + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = 512; + ctx_params.n_batch = 512; + ctx_params.no_perf = false; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + + const int n_iters = 3; + + // warm-up + llama_decode(ctx, batch); + llama_memory_clear(llama_get_memory(ctx), true); + llama_synchronize(ctx); + + for (int64_t t_pause_ms = 0; t_pause_ms <= 4000; t_pause_ms += 800) { + double t_sum_us = 0.0; + double t_sum2_us = 0.0; + + for (int i = 0; i < n_iters; i++) { + // this pause is important - it simulates "idle GPU" + std::this_thread::sleep_for(std::chrono::milliseconds(t_pause_ms)); + + const int64_t t_start_us = llama_time_us(); + + // this should take constant time + llama_decode(ctx, batch); + llama_synchronize(ctx); + + const int64_t t_end_us = llama_time_us(); + + const double t_cur_us = t_end_us - t_start_us; + +#if 1 + // print individual decode times + printf(" - decode time: %8.2f 
ms\n", t_cur_us / 1000);
+#endif
+
+            t_sum_us  += t_cur_us;
+            t_sum2_us += t_cur_us * t_cur_us;
+
+            llama_memory_clear(llama_get_memory(ctx), true);
+            llama_synchronize(ctx); // just in case
+        }
+
+        const double t_avg_us = t_sum_us / n_iters;
+
+        // sample standard deviation of the decode time over n_iters runs
+        const double t_dev_us = sqrt((t_sum2_us / (n_iters - 1)) - (t_avg_us * t_avg_us * n_iters) / (n_iters - 1));
+
+        printf("iters: %4d, pause: %5d ms, avg decode time: %8.2f +/- %4.2f ms\n", n_iters, (int) t_pause_ms, t_avg_us / 1000, t_dev_us / 1000);
+        fflush(stdout);
+    }
+
+    llama_free(ctx);
+    llama_model_free(model);
+
+    return 0;
+}
diff --git a/llama.cpp/examples/json_schema_pydantic_example.py b/llama.cpp/examples/json_schema_pydantic_example.py
new file mode 100644
index 0000000..19c0bdb
--- /dev/null
+++ b/llama.cpp/examples/json_schema_pydantic_example.py
@@ -0,0 +1,82 @@
+# Usage:
+#! ./llama-server -m some-model.gguf &
+#! pip install pydantic
+#! python json_schema_pydantic_example.py
+
+from pydantic import BaseModel, Field, TypeAdapter
+from annotated_types import MinLen
+from typing import Annotated, List, Optional
+import json, requests
+
+if True:
+
+    def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs):
+        '''
+        Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support
+        (llama.cpp server, llama-cpp-python, Anyscale / Together...)
+
+        The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below)
+        '''
+        response_format = None
+        type_adapter = None
+
+        if response_model:
+            type_adapter = TypeAdapter(response_model)
+            schema = type_adapter.json_schema()
+            messages = [{
+                "role": "system",
+                "content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}"
+            }] + messages
+            response_format={"type": "json_object", "schema": schema}
+
+        data = requests.post(endpoint, headers={"Content-Type": "application/json"},
+            json=dict(messages=messages, response_format=response_format, **kwargs)).json()
+        if 'error' in data:
+            raise Exception(data['error']['message'])
+
+        content = data["choices"][0]["message"]["content"]
+        return type_adapter.validate_json(content) if type_adapter else content
+
+else:
+
+    # This alternative branch uses Instructor + OpenAI client lib.
+    # Instructor supports streamed iterable responses, retries & more.
+    # (see https://python.useinstructor.com/)
+    #! pip install instructor openai
+    import instructor, openai
+    client = instructor.patch(
+        openai.OpenAI(api_key="123", base_url="http://localhost:8080"),
+        mode=instructor.Mode.JSON_SCHEMA)
+    create_completion = client.chat.completions.create
+
+
+if __name__ == '__main__':
+
+    class QAPair(BaseModel):
+        class Config:
+            extra = 'forbid'  # triggers additionalProperties: false in the JSON schema
+        question: str
+        concise_answer: str
+        justification: str
+        stars: Annotated[int, Field(ge=1, le=5)]
+
+    class PyramidalSummary(BaseModel):
+        class Config:
+            extra = 'forbid'  # triggers additionalProperties: false in the JSON schema
+        title: str
+        summary: str
+        question_answers: Annotated[List[QAPair], MinLen(2)]
+        sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]]
+
+    print("# Summary\n", create_completion(
+        model="...",
+        response_model=PyramidalSummary,
+        messages=[{
+            "role": "user",
+            "content": f"""
+                You are a highly efficient corporate document summarizer.
+                Create a pyramidal summary of an imaginary internal document about our company processes
+                (starting high-level, going down to each sub-section).
+                Keep questions short, and answers even shorter (trivia / quiz style).
+            """
+        }]))
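+
+# (editor's note: illustrative sketch, not part of the original patch) with the
+# QAPair/PyramidalSummary models above, create_completion() posts a body like:
+#
+#   {
+#     "messages": [
+#       {"role": "system", "content": "You respond in JSON format with the following schema: ..."},
+#       {"role": "user",   "content": "..."}
+#     ],
+#     "response_format": {"type": "json_object", "schema": { ...PyramidalSummary JSON schema... }}
+#   }
+#
+# and then validates the returned content back into a PyramidalSummary instance.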
diff --git a/llama.cpp/examples/json_schema_to_grammar.py b/llama.cpp/examples/json_schema_to_grammar.py
new file mode 100755
index 0000000..9fc90a3
--- /dev/null
+++ b/llama.cpp/examples/json_schema_to_grammar.py
@@ -0,0 +1,837 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+import itertools
+import json
+import re
+import sys
+from typing import Any, List, Optional, Set, Tuple, Union
+
+def _build_repetition(item_rule, min_items, max_items, separator_rule=None):
+
+    if max_items == 0:
+        return ""
+
+    if min_items == 0 and max_items == 1:
+        return f'{item_rule}?'
+
+    if not separator_rule:
+        if min_items == 1 and max_items is None:
+            return f'{item_rule}+'
+        elif min_items == 0 and max_items is None:
+            return f'{item_rule}*'
+        else:
+            return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}'
+
+    result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None)
+    return f'({result})?' if min_items == 0 else result
+
+def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True):
+    has_min = min_value != None
+    has_max = max_value != None
+
+    def digit_range(from_char: str, to_char: str):
+        out.append("[")
+        if from_char == to_char:
+            out.append(from_char)
+        else:
+            out.append(from_char)
+            out.append("-")
+            out.append(to_char)
+        out.append("]")
+
+    def more_digits(min_digits: int, max_digits: int):
+        out.append("[0-9]")
+        if min_digits == max_digits and min_digits == 1:
+            return
+        out.append("{")
+        out.append(str(min_digits))
+        if max_digits != min_digits:
+            out.append(",")
+            if max_digits != sys.maxsize:
+                out.append(str(max_digits))
+        out.append("}")
+
+    def uniform_range(from_str: str, to_str: str):
+        i = 0
+        while i < len(from_str) and from_str[i] == to_str[i]:
+            i += 1
+        if i > 0:
+            out.append("\"")
+            out.append(from_str[:i])
+            out.append("\"")
+        if i < len(from_str):
+            if i > 0:
+                out.append(" ")
+            sub_len = len(from_str) - i - 1
+            if sub_len > 0:
+                from_sub = from_str[i+1:]
+                to_sub = to_str[i+1:]
+                sub_zeros = "0" * sub_len
+                sub_nines = "9" * sub_len
+
+                to_reached = False
+                out.append("(")
+                if from_sub == sub_zeros:
+                    digit_range(from_str[i], chr(ord(to_str[i]) - 1))
+                    out.append(" ")
+                    more_digits(sub_len, sub_len)
+                else:
+                    out.append("[")
+                    out.append(from_str[i])
+                    out.append("] ")
+                    out.append("(")
+                    uniform_range(from_sub, sub_nines)
+                    out.append(")")
+                    if ord(from_str[i]) < ord(to_str[i]) - 1:
+                        out.append(" | ")
+                        if to_sub == sub_nines:
+                            digit_range(chr(ord(from_str[i]) + 1), to_str[i])
+                            to_reached = True
+                        else:
+                            digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1))
+                        out.append(" ")
+                        more_digits(sub_len, sub_len)
+                if not to_reached:
+                    out.append(" | ")
+                    digit_range(to_str[i], to_str[i])
+                    out.append(" ")
+                    uniform_range(sub_zeros, to_sub)
+                out.append(")")
+            else:
+                out.append("[")
+                out.append(from_str[i])
+                out.append("-")
+                out.append(to_str[i])
+                out.append("]")
+
+    if has_min and has_max:
+        if min_value < 0 and max_value < 0:
+            out.append("\"-\" (")
+            _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True)
+            out.append(")")
+            return
+
+        if min_value < 0:
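+            # (editor's sketch, not part of the original patch) a range straddling
+            # zero, e.g. min=-5 and max=12, is handled by mirroring the negative half
+            # behind a "-" sign and then clamping min_value to 0, producing roughly:
+            #     "-" (<rule for 0..5>) | <rule for 0..12>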
out.append("\"-\" (") + _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) + out.append(") | ") + min_value = 0 + + min_s = str(min_value) + max_s = str(max_value) + min_digits = len(min_s) + max_digits = len(max_s) + + for digits in range(min_digits, max_digits): + uniform_range(min_s, "9" * digits) + min_s = "1" + "0" * digits + out.append(" | ") + uniform_range(min_s, max_s) + return + + less_decimals = max(decimals_left - 1, 1) + + if has_min: + if min_value < 0: + out.append("\"-\" (") + _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) + out.append(") | [0] | [1-9] ") + more_digits(0, decimals_left - 1) + elif min_value == 0: + if top_level: + out.append("[0] | [1-9] ") + more_digits(0, less_decimals) + else: + more_digits(1, decimals_left) + elif min_value <= 9: + c = str(min_value) + range_start = '1' if top_level else '0' + if c > range_start: + digit_range(range_start, chr(ord(c) - 1)) + out.append(" ") + more_digits(1, less_decimals) + out.append(" | ") + digit_range(c, "9") + out.append(" ") + more_digits(0, less_decimals) + else: + min_s = str(min_value) + length = len(min_s) + c = min_s[0] + + if c > "1": + digit_range("1" if top_level else "0", chr(ord(c) - 1)) + out.append(" ") + more_digits(length, less_decimals) + out.append(" | ") + digit_range(c, c) + out.append(" (") + _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) + out.append(")") + if c < "9": + out.append(" | ") + digit_range(chr(ord(c) + 1), "9") + out.append(" ") + more_digits(length - 1, less_decimals) + return + + if has_max: + if max_value >= 0: + if top_level: + out.append("\"-\" [1-9] ") + more_digits(0, less_decimals) + out.append(" | ") + _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) + else: + out.append("\"-\" (") + _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) + out.append(")") + return + + raise RuntimeError("At least one of min_value or max_value must be set") + +class BuiltinRule: + def __init__(self, content: str, deps: list | None = None): + self.content = content + self.deps = deps or [] + +# Constraining spaces to prevent model "running away". +SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}' + +PRIMITIVE_RULES = { + 'boolean' : BuiltinRule('("true" | "false") space', []), + 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), + 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? 
"]" space', ['value']), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []), + 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), + 'null' : BuiltinRule('"null" space', []), +} + +# TODO: support "uri", "email" string formats +STRING_FORMAT_RULES = { + 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), + 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), + 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), + 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), +} + +DOTALL = '[\\U00000000-\\U0010FFFF]' +DOT = '[^\\x0A\\x0D]' + +RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) + +INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') +GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]') +GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') +GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'} + +NON_LITERAL_SET = set('|.()[]{}*+?') +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') + + +class SchemaConverter: + def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): + self._prop_order = prop_order + self._allow_fetch = allow_fetch + self._dotall = dotall + self._raw_pattern = raw_pattern + self._rules = { + 'space': SPACE_RULE, + } + self._refs = {} + self._refs_being_resolved = set() + + def _format_literal(self, literal): + escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( + lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)) or m.group(0), literal + ) + return f'"{escaped}"' + + def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: + ''' + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + ''' + assert len(literal) > 0, 'Empty literal not supported' + def recurse(i: int): + c = literal[i] + if maybe_escaped_underscores and c == '_': + yield f'[^{c}\\\\]' + yield ' | ' + yield f'"\\\\"? "{c}"' + else: + yield f'[^{c}]' + if i < len(literal) - 1: + yield ' | ' + yield self._format_literal(c) + yield ' (' + yield from recurse(i + 1) + yield ')?' 
+ + return ''.join(('(', *recurse(0), ')')) + + def _not_strings(self, strings): + class TrieNode: + def __init__(self): + self.children = {} + self.is_end_of_string = False + + def insert(self, string): + node = self + for c in string: + node = node.children.setdefault(c, TrieNode()) + node.is_end_of_string = True + + trie = TrieNode() + for s in strings: + trie.insert(s) + + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + out = ['["] ( '] + + def visit(node): + rejects = [] + first = True + for c in sorted(node.children.keys()): + child = node.children[c] + rejects.append(c) + if first: + first = False + else: + out.append(' | ') + out.append(f'[{c}]') + if child.children: + out.append(f' (') + visit(child) + out.append(')') + elif child.is_end_of_string: + out.append(f' {char_rule}+') + if node.children: + if not first: + out.append(' | ') + out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) + + out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + return ''.join(out) + + def _add_rule(self, name, rule): + esc_name = INVALID_RULE_CHARS_RE.sub('-', name) + if esc_name not in self._rules or self._rules[esc_name] == rule: + key = esc_name + else: + i = 0 + while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: + i += 1 + key = f'{esc_name}{i}' + self._rules[key] = rule + return key + + def resolve_refs(self, schema: dict, url: str): + ''' + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. + ''' + def visit(n: dict): + if isinstance(n, list): + return [visit(x) for x in n] + elif isinstance(n, dict): + ref = n.get('$ref') + if ref is not None and ref not in self._refs: + if ref.startswith('https://'): + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests + + frag_split = ref.split('#') + base_url = frag_split[0] + + target = self._refs.get(base_url) + if target is None: + target = self.resolve_refs(requests.get(ref).json(), base_url) + self._refs[base_url] = target + + if len(frag_split) == 1 or frag_split[-1] == '': + return target + elif ref.startswith('#/'): + target = schema + ref = f'{url}{ref}' + n['$ref'] = ref + else: + raise ValueError(f'Unsupported ref {ref}') + + for sel in ref.split('#')[-1].split('/')[1:]: + assert target is not None, f'Error resolving ref {ref}: {sel} not in {target}' + if isinstance(target, list): + try: + sel_index = int(sel) + except ValueError: + raise ValueError(f'Error resolving ref {ref}: {sel} not in {target}') + assert 0 <= sel_index < len(target), f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel_index] + else: + assert sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + self._refs[ref] = target + else: + for v in n.values(): + visit(v) + + return n + return visit(schema) + + def _generate_union_rule(self, name, alt_schemas): + return ' | '.join(( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + )) + + def _visit_pattern(self, pattern, name): + ''' + Transforms a regular expression pattern into a GBNF rule. 
+
+            Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
+            Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md
+
+            Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
+
+            Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
+            we define sub-rules to keep the output lean.
+        '''
+
+        assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
+        pattern = pattern[1:-1]
+        sub_rule_ids = {}
+
+        i = 0
+        length = len(pattern)
+
+        def to_rule(s: tuple[str, bool]) -> str:
+            (txt, is_literal) = s
+            return "\"" + txt + "\"" if is_literal else txt
+
+        def transform() -> tuple[str, bool]:
+            '''
+                Parse a unit at index i (advancing it), and return its string representation
+                and whether it's a literal.
+            '''
+            nonlocal i
+            nonlocal pattern
+            nonlocal sub_rule_ids
+
+            start = i
+            # For each component of this sequence, store its string representation and whether it's a literal.
+            # We only need a flat structure here to apply repetition operators to the last item, and
+            # to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially;
+            # GBNF's syntax is luckily very close to regular expressions!)
+            seq: list[tuple[str, bool]] = []
+
+            def get_dot():
+                if self._dotall:
+                    rule = DOTALL
+                else:
+                    # Accept any character... except \n and \r line break chars (\x0A and \x0D)
+                    rule = DOT
+                return self._add_rule(f'dot', rule)
+
+            def join_seq():
+                nonlocal seq
+                ret = []
+                for is_literal, g in itertools.groupby(seq, lambda x: x[1]):
+                    if is_literal:
+                        ret.append((''.join(x[0] for x in g), True))
+                    else:
+                        ret.extend(g)
+                if len(ret) == 1:
+                    return ret[0]
+                return (' '.join(to_rule(x) for x in ret), False)
+
+            while i < length:
+                c = pattern[i]
+                if c == '.':
+                    seq.append((get_dot(), False))
+                    i += 1
+                elif c == '(':
+                    i += 1
+                    if i < length:
+                        assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
+                    seq.append((f'({to_rule(transform())})', False))
+                elif c == ')':
+                    i += 1
+                    assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
+                    return join_seq()
+                elif c == '[':
+                    square_brackets = c
+                    i += 1
+                    while i < length and pattern[i] != ']':
+                        if pattern[i] == '\\':
+                            square_brackets += pattern[i:i+2]
+                            i += 2
+                        else:
+                            square_brackets += pattern[i]
+                            i += 1
+                    assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
+                    square_brackets += ']'
+                    i += 1
+                    seq.append((square_brackets, False))
+                elif c == '|':
+                    seq.append(('|', False))
+                    i += 1
+                elif c in ('*', '+', '?'):
+                    seq[-1] = (to_rule(seq[-1]) + c, False)
+                    i += 1
+                elif c == '{':
+                    curly_brackets = c
+                    i += 1
+                    while i < length and pattern[i] != '}':
+                        curly_brackets += pattern[i]
+                        i += 1
+                    assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
+                    curly_brackets += '}'
+                    i += 1
+                    nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
+                    min_times = 0
+                    max_times = None
+                    try:
+                        if len(nums) == 1:
+                            min_times = int(nums[0])
+                            max_times = min_times
+                        else:
+                            assert len(nums) == 2
+                            min_times = int(nums[0]) if nums[0] else 0
+                            max_times = int(nums[1]) if nums[1] else None
+                    except ValueError:
+                        raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
+
+                    (sub, sub_is_literal) = seq[-1]
+
+                    if not sub_is_literal:
+                        id = sub_rule_ids.get(sub)
+                        if id is None:
+                            id =
self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + sub_rule_ids[sub] = id + sub = id + + seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False) + else: + literal = '' + while i < length: + if pattern[i] == '\\' and i < length - 1: + next = pattern[i + 1] + if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: + i += 1 + literal += pattern[i] + i += 1 + else: + literal += pattern[i:i+2] + i += 2 + elif pattern[i] == '"' and not self._raw_pattern: + literal += '\\"' + i += 1 + elif pattern[i] not in NON_LITERAL_SET and \ + (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): + literal += pattern[i] + i += 1 + else: + break + if literal: + seq.append((literal, True)) + + return join_seq() + + return self._add_rule( + name, + to_rule(transform()) if self._raw_pattern \ + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") + + + def _resolve_ref(self, ref): + ref_fragment = ref.split('#')[-1] + ref_name = 'ref' + re.sub(r'[^a-zA-Z0-9-]+', '-', ref_fragment) + if ref_name not in self._rules and ref not in self._refs_being_resolved: + self._refs_being_resolved.add(ref) + resolved = self._refs[ref] + ref_name = self.visit(resolved, ref_name) + self._refs_being_resolved.remove(ref) + return ref_name + + def _generate_constant_rule(self, value): + return self._format_literal(json.dumps(value)) + + def visit(self, schema, name): + schema_type = schema.get('type') + schema_format = schema.get('format') + rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + + if (ref := schema.get('$ref')) is not None: + return self._add_rule(rule_name, self._resolve_ref(ref)) + + elif 'oneOf' in schema or 'anyOf' in schema: + return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + + elif isinstance(schema_type, list): + return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) + + elif 'const' in schema: + return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') + + elif 'enum' in schema: + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' + return self._add_rule(rule_name, rule) + + elif schema_type in (None, 'object') and \ + ('properties' in schema or \ + ('additionalProperties' in schema and schema['additionalProperties'] is not True)): + required = set(schema.get('required', [])) + properties = list(schema.get('properties', {}).items()) + return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + + elif schema_type in (None, 'object', 'string') and 'allOf' in schema: + required = set() + properties = [] + enum_sets = [] + hybrid_name = name + def add_component(comp_schema, is_required): + if (ref := comp_schema.get('$ref')) is not None: + comp_schema = self._refs[ref] + + if 'properties' in comp_schema: + for prop_name, prop_schema in comp_schema['properties'].items(): + properties.append((prop_name, prop_schema)) + if is_required: + required.add(prop_name) + + if 'enum' in comp_schema: + enum_sets.append(set(comp_schema['enum'])) + + for t in schema['allOf']: + if 'anyOf' in t: + for tt in t['anyOf']: + add_component(tt, is_required=False) + else: + add_component(t, is_required=True) + + if enum_sets: + enum_intersection = enum_sets[0] + for s in enum_sets[1:]: + enum_intersection &= s + + if enum_intersection: + rule = '(' + ' | 
'.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + return self._add_rule(rule_name, rule) + + return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) + + elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): + items = schema.get('items') or schema['prefixItems'] + if isinstance(items, list): + return self._add_rule( + rule_name, + '"[" space ' + + ' "," space '.join( + self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') + for i, item in enumerate(items)) + + ' "]" space') + else: + item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') + min_items = schema.get("minItems", 0) + max_items = schema.get("maxItems") + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + + elif schema_type in (None, 'string') and 'pattern' in schema: + return self._visit_pattern(schema['pattern'], rule_name) + + elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): + return self._add_primitive( + 'root' if rule_name == 'root' else schema_format, + PRIMITIVE_RULES['uuid'] + ) + + elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: + prim_name = f'{schema_format}-string' + return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) + + elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): + char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + min_len = schema.get('minLength', 0) + max_len = schema.get('maxLength') + + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + + elif schema_type in (None, 'integer') and \ + ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): + min_value = None + max_value = None + if 'minimum' in schema: + min_value = schema['minimum'] + elif 'exclusiveMinimum' in schema: + min_value = schema['exclusiveMinimum'] + 1 + if 'maximum' in schema: + max_value = schema['maximum'] + elif 'exclusiveMaximum' in schema: + max_value = schema['exclusiveMaximum'] - 1 + + out = ["("] + _generate_min_max_int(min_value, max_value, out) + out.append(") space") + return self._add_rule(rule_name, ''.join(out)) + + elif (schema_type == 'object') or (len(schema) == 0): + return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + + else: + assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + + def _add_primitive(self, name: str, rule: BuiltinRule): + n = self._add_rule(name, rule.content) + + for dep in rule.deps: + dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) + assert dep_rule, f'Rule {dep} not known' + if dep not in self._rules: + self._add_primitive(dep, dep_rule) + return n + + def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): + prop_order = self._prop_order + # sort by position in prop_order (if specified) then by original order + sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: 
(prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + + prop_kv_rule_names = {} + for prop_name, prop_schema in properties: + prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_kv_rule_names[prop_name] = self._add_rule( + f'{name}{"-" if name else ""}{prop_name}-kv', + fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' + ) + required_props = [k for k in sorted_props if k in required] + optional_props = [k for k in sorted_props if k not in required] + + if additional_properties is not None and additional_properties != False: + sub_name = f'{name}{"-" if name else ""}additional' + value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ + self._add_primitive('value', PRIMITIVE_RULES['value']) + key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ + else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + + prop_kv_rule_names["*"] = self._add_rule( + f'{sub_name}-kv', + f'{key_rule} ":" space {value_rule}' + ) + optional_props.append("*") + + rule = '"{" space ' + rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) + + if optional_props: + rule += ' (' + if required_props: + rule += ' "," space ( ' + + def get_recursive_refs(ks, first_is_optional): + [k, *rest] = ks + kv_rule_name = prop_kv_rule_names[k] + comma_ref = f'( "," space {kv_rule_name} )' + if first_is_optional: + res = comma_ref + ('*' if k == '*' else '?') + else: + res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') + if len(rest) > 0: + res += ' ' + self._add_rule( + f'{name}{"-" if name else ""}{k}-rest', + get_recursive_refs(rest, first_is_optional=True) + ) + return res + + rule += ' | '.join( + get_recursive_refs(optional_props[i:], first_is_optional=False) + for i in range(len(optional_props)) + ) + if required_props: + rule += ' )' + rule += ' )?' + + rule += ' "}" space' + + return rule + + def format_grammar(self): + return '\n'.join( + f'{name} ::= {rule}' + for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) + ) + + +def main(args_in = None): + parser = argparse.ArgumentParser( + description=''' + Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a + given JSON schema. Only a subset of JSON schema features are supported; more may be + added in the future. + ''', + ) + parser.add_argument( + '--prop-order', + default=[], + type=lambda s: s.split(','), + help=''' + comma-separated property names defining the order of precedence for object properties; + properties not specified here are given lower precedence than those that are, and + are kept in their original order from the schema. Required properties are always + given precedence over optional properties. 
+ ''' + ) + parser.add_argument( + '--allow-fetch', + action='store_true', + default=False, + help='Whether to allow fetching referenced schemas over HTTPS') + parser.add_argument( + '--dotall', + action='store_true', + default=False, + help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns') + parser.add_argument( + '--raw-pattern', + action='store_true', + default=False, + help='Treats string patterns as raw patterns w/o quotes (or quote escapes)') + + parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') + args = parser.parse_args(args_in) + + if args.schema.startswith('https://'): + url = args.schema + import requests + schema = requests.get(url).json() + elif args.schema == '-': + url = 'stdin' + schema = json.load(sys.stdin) + else: + url = f'file://{args.schema}' + with open(args.schema) as f: + schema = json.load(f) + converter = SchemaConverter( + prop_order={name: idx for idx, name in enumerate(args.prop_order)}, + allow_fetch=args.allow_fetch, + dotall=args.dotall, + raw_pattern=args.raw_pattern) + schema = converter.resolve_refs(schema, url) + converter.visit(schema, '') + print(converter.format_grammar()) + + +if __name__ == '__main__': + main() diff --git a/llama.cpp/examples/llama.android/.gitignore b/llama.cpp/examples/llama.android/.gitignore new file mode 100644 index 0000000..347e252 --- /dev/null +++ b/llama.cpp/examples/llama.android/.gitignore @@ -0,0 +1,33 @@ +# Gradle files +.gradle/ +build/ + +# Local configuration file (sdk path, etc) +local.properties + +# Log/OS Files +*.log + +# Android Studio generated files and folders +captures/ +.externalNativeBuild/ +.cxx/ +*.apk +output.json + +# IntelliJ +*.iml +.idea/ +misc.xml +deploymentTargetDropDown.xml +render.experimental.xml + +# Keystore files +*.jks +*.keystore + +# Google Services (e.g. 
APIs or Firebase) +google-services.json + +# Android Profiling +*.hprof diff --git a/llama.cpp/examples/llama.android/app/.gitignore b/llama.cpp/examples/llama.android/app/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/llama.cpp/examples/llama.android/app/build.gradle.kts b/llama.cpp/examples/llama.android/app/build.gradle.kts new file mode 100644 index 0000000..2edfe98 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/build.gradle.kts @@ -0,0 +1,58 @@ +plugins { + alias(libs.plugins.android.application) + alias(libs.plugins.jetbrains.kotlin.android) +} + +android { + namespace = "com.example.llama" + compileSdk = 36 + + defaultConfig { + applicationId = "com.example.llama.aichat" + + minSdk = 33 + targetSdk = 36 + + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary = true + } + } + + buildTypes { + debug { + isMinifyEnabled = true + isShrinkResources = true + proguardFiles( + getDefaultProguardFile("proguard-android.txt"), + "proguard-rules.pro" + ) + } + release { + isMinifyEnabled = true + isShrinkResources = true + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } +} + +dependencies { + implementation(libs.bundles.androidx) + implementation(libs.material) + + implementation(project(":lib")) + + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) + androidTestImplementation(libs.androidx.espresso.core) +} diff --git a/llama.cpp/examples/llama.android/app/proguard-rules.pro b/llama.cpp/examples/llama.android/app/proguard-rules.pro new file mode 100644 index 0000000..358020d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/proguard-rules.pro @@ -0,0 +1,29 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
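+# (editor's note, not part of the original patch) the -keep rules below preserve
+# the com.arm.aichat API surface, which is presumably reached reflectively or via
+# JNI and so must survive R8 renaming. The -assumenosideeffects block then lets
+# R8 strip verbose/debug Log calls from minified builds (both build types in
+# build.gradle.kts enable isMinifyEnabled).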
+#-renamesourcefileattribute SourceFile + +-keep class com.arm.aichat.* { *; } +-keep class com.arm.aichat.gguf.* { *; } + +-assumenosideeffects class android.util.Log { + public static int v(...); + public static int d(...); +} diff --git a/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml b/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..8f7c606 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml @@ -0,0 +1,27 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android"> + + <application + android:allowBackup="true" + android:dataExtractionRules="@xml/data_extraction_rules" + android:extractNativeLibs="true" + android:fullBackupContent="@xml/backup_rules" + android:icon="@mipmap/ic_launcher_round" + android:label="@string/app_name" + android:roundIcon="@mipmap/ic_launcher_round" + android:supportsRtl="true" + android:theme="@style/Theme.AiChatSample" + > + + <activity + android:name=".MainActivity" + android:exported="true"> + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + +</manifest> diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt new file mode 100644 index 0000000..872ec2b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -0,0 +1,275 @@ +package com.example.llama + +import android.net.Uri +import android.os.Bundle +import android.util.Log +import android.widget.EditText +import android.widget.TextView +import android.widget.Toast +import androidx.activity.addCallback +import androidx.activity.enableEdgeToEdge +import androidx.activity.result.contract.ActivityResultContracts +import androidx.appcompat.app.AppCompatActivity +import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.LinearLayoutManager +import androidx.recyclerview.widget.RecyclerView +import com.arm.aichat.AiChat +import com.arm.aichat.InferenceEngine +import com.arm.aichat.gguf.GgufMetadata +import com.arm.aichat.gguf.GgufMetadataReader +import com.google.android.material.floatingactionbutton.FloatingActionButton +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File +import java.io.FileOutputStream +import java.io.InputStream +import java.util.UUID + +class MainActivity : AppCompatActivity() { + + // Android views + private lateinit var ggufTv: TextView + private lateinit var messagesRv: RecyclerView + private lateinit var userInputEt: EditText + private lateinit var userActionFab: FloatingActionButton + + // Arm AI Chat inference engine + private lateinit var engine: InferenceEngine + private var generationJob: Job? = null + + // Conversation states + private var isModelReady = false + private val messages = mutableListOf<Message>() + private val lastAssistantMsg = StringBuilder() + private val messageAdapter = MessageAdapter(messages) + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + enableEdgeToEdge() + setContentView(R.layout.activity_main) + // View model boilerplate and state management are out of scope for this basic sample + onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") } + + // Find views + ggufTv = findViewById(R.id.gguf) + messagesRv = findViewById(R.id.messages) + messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true } + messagesRv.adapter = messageAdapter + userInputEt = findViewById(R.id.user_input) + userActionFab = findViewById(R.id.fab) + + // Arm AI Chat initialization + lifecycleScope.launch(Dispatchers.Default) { + engine = AiChat.getInferenceEngine(applicationContext) + } + + // Handle taps on the CTA button + userActionFab.setOnClickListener { + if (isModelReady) { + // If the model is ready, validate the input and send it to the engine + handleUserInput() + } else { + // Otherwise, prompt the user to select a GGUF model file on the device + getContent.launch(arrayOf("*/*")) + } + } + } + + private val getContent = registerForActivityResult( + ActivityResultContracts.OpenDocument() + ) { uri -> + Log.i(TAG, "Selected file uri:\n $uri") + uri?.let { handleSelectedModel(it) } + } + + /** + * Handles the file Uri from the [getContent] result + */ + private fun handleSelectedModel(uri: Uri) { + // Update UI states + userActionFab.isEnabled = false + userInputEt.hint = "Parsing GGUF..." + ggufTv.text = "Parsing metadata from selected file \n$uri" + + lifecycleScope.launch(Dispatchers.IO) { + // Parse GGUF metadata + Log.i(TAG, "Parsing GGUF metadata...") + contentResolver.openInputStream(uri)?.use { + GgufMetadataReader.create().readStructuredMetadata(it) + }?.let { metadata -> + // Update UI to show GGUF metadata to user + Log.i(TAG, "GGUF parsed: \n$metadata") + withContext(Dispatchers.Main) { + ggufTv.text = metadata.toString() + } + + // Ensure the model file is available + val modelName = metadata.filename() + FILE_EXTENSION_GGUF + contentResolver.openInputStream(uri)?.use { input -> + ensureModelFile(modelName, input) + }?.let { modelFile -> + loadModel(modelName, modelFile) + + withContext(Dispatchers.Main) { + isModelReady = true + userInputEt.hint = "Type and send a message!" + userInputEt.isEnabled = true + userActionFab.setImageResource(R.drawable.outline_send_24) + userActionFab.isEnabled = true + } + } + } + } + } + + /** + * Prepare the model file within the app's private storage + */ + private suspend fun ensureModelFile(modelName: String, input: InputStream) = + withContext(Dispatchers.IO) { + File(ensureModelsDirectory(), modelName).also { file -> + // Copy the file into local storage if not yet done + if (!file.exists()) { + Log.i(TAG, "Start copying file to $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Copying file..." + } + + FileOutputStream(file).use { input.copyTo(it) } + Log.i(TAG, "Finished copying file to $modelName") + } else { + Log.i(TAG, "File already exists $modelName") + } + } + } + + /** + * Load the model file from the app's private storage + */ + private suspend fun loadModel(modelName: String, modelFile: File) = + withContext(Dispatchers.IO) { + Log.i(TAG, "Loading model $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Loading model..."
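+ // Note: UI mutations hop back to the main dispatcher; the surrounding Dispatchers.IO context is reserved for the blocking engine.loadModel() call below.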
+ } + engine.loadModel(modelFile.path) + } + + /** + * Validate and send the user message into [InferenceEngine] + */ + private fun handleUserInput() { + userInputEt.text.toString().also { userMsg -> + if (userMsg.isEmpty()) { + Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show() + } else { + userInputEt.text = null + userInputEt.isEnabled = false + userActionFab.isEnabled = false + + // Update message states + messages.add(Message(UUID.randomUUID().toString(), userMsg, true)) + lastAssistantMsg.clear() + messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false)) + + generationJob = lifecycleScope.launch(Dispatchers.Default) { + engine.sendUserPrompt(userMsg) + .onCompletion { + withContext(Dispatchers.Main) { + userInputEt.isEnabled = true + userActionFab.isEnabled = true + } + }.collect { token -> + withContext(Dispatchers.Main) { + val messageCount = messages.size + check(messageCount > 0 && !messages[messageCount - 1].isUser) + + messages.removeAt(messageCount - 1).copy( + content = lastAssistantMsg.append(token).toString() + ).let { messages.add(it) } + + messageAdapter.notifyItemChanged(messages.size - 1) + } + } + } + } + } + } + + /** + * Run a benchmark with the model file + */ + @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers") + private suspend fun runBenchmark(modelName: String, modelFile: File) = + withContext(Dispatchers.Default) { + Log.i(TAG, "Start benchmarking $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Running benchmark..." + } + engine.bench( + pp = BENCH_PROMPT_PROCESSING_TOKENS, + tg = BENCH_TOKEN_GENERATION_TOKENS, + pl = BENCH_SEQUENCE, + nr = BENCH_REPETITION + ).let { result -> + messages.add(Message(UUID.randomUUID().toString(), result, false)) + withContext(Dispatchers.Main) { + messageAdapter.notifyItemChanged(messages.size - 1) + } + } + }
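+ // The bench parameters map onto the native benchmark (see benchModel() in ai_chat.cpp): pp = prompt-processing tokens per repetition, tg = generated tokens per repetition, pl = parallel sequences per decode, nr = repetitions averaged into the report. + // With the constants below, runBenchmark() above effectively calls: engine.bench(pp = 512, tg = 128, pl = 1, nr = 3)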
+ /** + * Create the `models` directory if it does not exist. + */ + private fun ensureModelsDirectory() = + File(filesDir, DIRECTORY_MODELS).also { + if (it.exists() && !it.isDirectory) { it.delete() } + if (!it.exists()) { it.mkdir() } + } + + override fun onStop() { + generationJob?.cancel() + super.onStop() + } + + override fun onDestroy() { + engine.destroy() + super.onDestroy() + } + + companion object { + private val TAG = MainActivity::class.java.simpleName + + private const val DIRECTORY_MODELS = "models" + private const val FILE_EXTENSION_GGUF = ".gguf" + + private const val BENCH_PROMPT_PROCESSING_TOKENS = 512 + private const val BENCH_TOKEN_GENERATION_TOKENS = 128 + private const val BENCH_SEQUENCE = 1 + private const val BENCH_REPETITION = 3 + } +} + +fun GgufMetadata.filename() = when { + basic.name != null -> { + basic.name?.let { name -> + basic.sizeLabel?.let { size -> + "$name-$size" + } ?: name + } + } + architecture?.architecture != null -> { + architecture?.architecture?.let { arch -> + basic.uuid?.let { uuid -> + "$arch-$uuid" + } ?: "$arch-${System.currentTimeMillis()}" + } + } + else -> { + "model-${System.currentTimeMillis().toHexString()}" + } +} diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt new file mode 100644 index 0000000..0439f96 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt @@ -0,0 +1,51 @@ +package com.example.llama + +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import android.widget.TextView +import androidx.recyclerview.widget.RecyclerView + +data class Message( + val id: String, + val content: String, + val isUser: Boolean +) + +class MessageAdapter( + private val messages: List<Message> +) : RecyclerView.Adapter<RecyclerView.ViewHolder>() { + + companion object { + private const val VIEW_TYPE_USER = 1 + private const val VIEW_TYPE_ASSISTANT = 2 + } + + override fun getItemViewType(position: Int): Int { + return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT + } + + override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder { + val layoutInflater = LayoutInflater.from(parent.context) + return if (viewType == VIEW_TYPE_USER) { + val view = layoutInflater.inflate(R.layout.item_message_user, parent, false) + UserMessageViewHolder(view) + } else { + val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false) + AssistantMessageViewHolder(view) + } + } + + override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) { + val message = messages[position] + if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) { + val textView = holder.itemView.findViewById<TextView>(R.id.msg_content) + textView.text = message.content + } + } + + override fun getItemCount(): Int = messages.size + + class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) + class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) +} diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml new file mode 100644 index 0000000..f90c3db --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml @@ -0,0 +1,4 @@ +<shape xmlns:android="http://schemas.android.com/apk/res/android"
android:shape="rectangle"> + <solid android:color="#E5E5EA" /> + <corners android:radius="16dp" /> +</shape> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml new file mode 100644 index 0000000..3ca7dae --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml @@ -0,0 +1,4 @@ +<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle"> + <solid android:color="#4285F4" /> + <corners android:radius="16dp" /> +</shape> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000..07d5da9 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ +<?xml version="1.0" encoding="utf-8"?> +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="108dp" + android:height="108dp" + android:viewportWidth="108" + android:viewportHeight="108"> + <path + android:fillColor="#3DDC84" + android:pathData="M0,0h108v108h-108z" /> + <path + android:fillColor="#00000000" + android:pathData="M9,0L9,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,0L19,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M29,0L29,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M39,0L39,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M49,0L49,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M59,0L59,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M69,0L69,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M79,0L79,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M89,0L89,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M99,0L99,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,9L108,9" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,19L108,19" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,29L108,29" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,39L108,39" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,49L108,49" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,59L108,59" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + 
android:pathData="M0,69L108,69" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,79L108,79" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,89L108,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,99L108,99" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,29L89,29" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,39L89,39" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,49L89,49" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,59L89,59" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,69L89,69" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,79L89,79" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M29,19L29,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M39,19L39,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M49,19L49,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M59,19L59,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M69,19L69,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M79,19L79,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 0000000..7706ab9 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:aapt="http://schemas.android.com/aapt" + android:width="108dp" + android:height="108dp" + android:viewportWidth="108" + android:viewportHeight="108"> + <path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z"> + <aapt:attr name="android:fillColor"> + <gradient + android:endX="85.84757" + android:endY="92.4963" + android:startX="42.9492" + android:startY="49.59793" + android:type="linear"> + <item + android:color="#44000000" + android:offset="0.0" /> + <item + android:color="#00000000" + android:offset="1.0" /> + </gradient> + </aapt:attr> + </path> + <path + android:fillColor="#FFFFFF" + android:fillType="nonZero" + android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 
31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z" + android:strokeWidth="1" + android:strokeColor="#00000000" /> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml new file mode 100644 index 0000000..f58b501 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml @@ -0,0 +1,10 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="24dp" + android:height="24dp" + android:viewportWidth="24" + android:viewportHeight="24" + android:tint="?attr/colorControlNormal"> + <path + android:fillColor="@android:color/white" + android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml new file mode 100644 index 0000000..712adc0 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml @@ -0,0 +1,11 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="24dp" + android:height="24dp" + android:viewportWidth="24" + android:viewportHeight="24" + android:tint="?attr/colorControlNormal" + android:autoMirrored="true"> + <path + android:fillColor="@android:color/white" + android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 0000000..d15772b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,77 @@ +<?xml version="1.0" encoding="utf-8"?> +<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:app="http://schemas.android.com/apk/res-auto" + xmlns:tools="http://schemas.android.com/tools" + android:id="@+id/main" + android:layout_height="match_parent" + android:layout_width="match_parent"> + + <LinearLayout + android:fitsSystemWindows="true" + android:layout_width="match_parent" + android:layout_height="match_parent" + android:orientation="vertical" + android:layout_marginEnd="4dp" + tools:context=".MainActivity"> + + <ScrollView + android:layout_width="match_parent" + android:layout_height="0dp" + android:layout_weight="1" + android:fadeScrollbars="false"> + + <TextView + android:id="@+id/gguf" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:padding="16dp" + android:text="Selected GGUF model's metadata will show here." 
+ style="@style/TextAppearance.MaterialComponents.Body2" /> + + </ScrollView> + + <com.google.android.material.divider.MaterialDivider + android:layout_width="match_parent" + android:layout_height="2dp" + android:layout_marginHorizontal="16dp" /> + + <androidx.recyclerview.widget.RecyclerView + android:id="@+id/messages" + android:layout_width="match_parent" + android:layout_height="0dp" + android:layout_weight="4" + android:fadeScrollbars="false" + android:scrollbars="vertical" + app:reverseLayout="true" + tools:listitem="@layout/item_message_assistant"/> + + <LinearLayout + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:orientation="horizontal" + android:paddingStart="16dp" + android:paddingEnd="4dp"> + + <EditText + android:id="@+id/user_input" + android:enabled="false" + android:layout_width="0dp" + android:layout_weight="1" + android:layout_height="match_parent" + android:padding="8dp" + style="@style/TextAppearance.MaterialComponents.Body2" + android:hint="Please first pick a GGUF model file to import." /> + + <com.google.android.material.floatingactionbutton.FloatingActionButton + android:id="@+id/fab" + android:enabled="true" + style="@style/Widget.Material3.FloatingActionButton.Primary" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:layout_margin="12dp" + android:src="@drawable/outline_folder_open_24" /> + + </LinearLayout> + + </LinearLayout> +</androidx.constraintlayout.widget.ConstraintLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml new file mode 100644 index 0000000..2c8e4bc --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml @@ -0,0 +1,16 @@ +<?xml version="1.0" encoding="utf-8"?> +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:layout_marginHorizontal="16dp" + android:layout_marginVertical="8dp" + android:gravity="start"> + + <TextView + android:id="@+id/msg_content" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:background="@drawable/bg_assistant_message" + android:padding="12dp" + android:textColor="@android:color/black" /> +</LinearLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml new file mode 100644 index 0000000..5aa79f2 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml @@ -0,0 +1,16 @@ +<?xml version="1.0" encoding="utf-8"?> +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:layout_marginHorizontal="16dp" + android:layout_marginVertical="8dp" + android:gravity="end"> + + <TextView + android:id="@+id/msg_content" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:background="@drawable/bg_user_message" + android:padding="12dp" + android:textColor="@android:color/white" /> +</LinearLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml new file mode 100644 index 0000000..b3e26b4 --- /dev/null +++ 
b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android"> + <background android:drawable="@drawable/ic_launcher_background" /> + <foreground android:drawable="@drawable/ic_launcher_foreground" /> + <monochrome android:drawable="@drawable/ic_launcher_foreground" /> +</adaptive-icon> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml new file mode 100644 index 0000000..b3e26b4 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android"> + <background android:drawable="@drawable/ic_launcher_background" /> + <foreground android:drawable="@drawable/ic_launcher_foreground" /> + <monochrome android:drawable="@drawable/ic_launcher_foreground" /> +</adaptive-icon> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..c209e78 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..b2dfe3d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..4f0f1d6 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..62b611d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..948a307 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..1b9a695 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..28d4b77 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp diff --git 
a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..9287f50 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..aa7d642 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..9126ae3 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000..ca1931b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ +<?xml version="1.0" encoding="utf-8"?> +<resources> + <color name="purple_200">#FFBB86FC</color> + <color name="purple_500">#FF6200EE</color> + <color name="purple_700">#FF3700B3</color> + <color name="teal_200">#FF03DAC5</color> + <color name="teal_700">#FF018786</color> + <color name="black">#FF000000</color> + <color name="white">#FFFFFFFF</color> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml new file mode 100644 index 0000000..36059fc --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ +<resources> + <string name="app_name">AI Chat basic sample</string> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml new file mode 100644 index 0000000..2e4fdad --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml @@ -0,0 +1,10 @@ +<?xml version="1.0" encoding="utf-8"?> +<resources> + + <style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar"> + <!-- Customize your light theme here. --> + <!-- <item name="colorPrimary">@color/my_light_primary</item> --> + </style> + + <style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" /> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml b/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml new file mode 100644 index 0000000..148c18b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="utf-8"?><!-- + Sample backup rules file; uncomment and customize as necessary. + See https://developer.android.com/guide/topics/data/autobackup + for details. 
+ Note: This file is ignored for devices older than API 31 + See https://developer.android.com/about/versions/12/backup-restore +--> +<full-backup-content> + <!-- + <include domain="sharedpref" path="."/> + <exclude domain="sharedpref" path="device.xml"/> +--> +</full-backup-content> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml b/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml new file mode 100644 index 0000000..0c4f95c --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml @@ -0,0 +1,19 @@ +<?xml version="1.0" encoding="utf-8"?><!-- + Sample data extraction rules file; uncomment and customize as necessary. + See https://developer.android.com/about/versions/12/backup-restore#xml-changes + for details. +--> +<data-extraction-rules> + <cloud-backup> + <!-- TODO: Use <include> and <exclude> to control what is backed up. + <include .../> + <exclude .../> + --> + </cloud-backup> + <!-- + <device-transfer> + <include .../> + <exclude .../> + </device-transfer> + --> +</data-extraction-rules> diff --git a/llama.cpp/examples/llama.android/build.gradle.kts b/llama.cpp/examples/llama.android/build.gradle.kts new file mode 100644 index 0000000..076a0f1 --- /dev/null +++ b/llama.cpp/examples/llama.android/build.gradle.kts @@ -0,0 +1,6 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +plugins { + alias(libs.plugins.android.application) apply false + alias(libs.plugins.android.library) apply false + alias(libs.plugins.jetbrains.kotlin.android) apply false +} diff --git a/llama.cpp/examples/llama.android/gradle.properties b/llama.cpp/examples/llama.android/gradle.properties new file mode 100644 index 0000000..8888cc9 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle.properties @@ -0,0 +1,24 @@ +# Project-wide Gradle settings. +# IDE (e.g. Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects.
More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official +# Enables namespacing of each library's R class so that its R class includes only the +# resources declared in the library itself and none from the library's dependencies, +# thereby reducing the size of the R class for that library +android.nonTransitiveRClass=true +android.native.buildOutput=verbose diff --git a/llama.cpp/examples/llama.android/gradle/libs.versions.toml b/llama.cpp/examples/llama.android/gradle/libs.versions.toml new file mode 100644 index 0000000..8ff2afd --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/libs.versions.toml @@ -0,0 +1,53 @@ +[versions] + +# Plugins +agp = "8.13.2" +kotlin = "2.3.0" + +# AndroidX +activity = "1.12.2" +appcompat = "1.7.1" +core-ktx = "1.17.0" +constraint-layout = "2.2.1" +datastore-preferences = "1.2.0" + +# Material +material = "1.13.0" + +# Testing +espresso-core = "3.7.0" +androidx-junit = "1.3.0" +junit = "4.13.2" + + +[plugins] +android-application = { id = "com.android.application", version.ref = "agp" } +android-library = { id = "com.android.library", version.ref = "agp" } +jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } + + +[libraries] + +# AndroidX +androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" } +androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" } +androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" } +androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" } +androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" } + +#Material +material = { group = "com.google.android.material", name = "material", version.ref = "material" } + +# Testing +androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" } +androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" } +junit = { group = "junit", name = "junit", version.ref = "junit" } + +[bundles] +androidx = [ + "androidx-activity", + "androidx-appcompat", + "androidx-constraintlayout", + "androidx-core-ktx", + "androidx-datastore-preferences", +] diff --git a/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar Binary files differnew file mode 100644 index 0000000..e708b1c --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar diff --git a/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..6b993e9 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Tue Apr 01 11:15:06 PDT 2025 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists 
+distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/llama.cpp/examples/llama.android/gradlew b/llama.cpp/examples/llama.android/gradlew new file mode 100755 index 0000000..4f906e0 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradlew @@ -0,0 +1,185 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/llama.cpp/examples/llama.android/lib/.gitignore b/llama.cpp/examples/llama.android/lib/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/.gitignore @@ -0,0 +1 @@ +/build diff --git a/llama.cpp/examples/llama.android/lib/build.gradle.kts b/llama.cpp/examples/llama.android/lib/build.gradle.kts new file mode 100644 index 0000000..9b290d6 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/build.gradle.kts @@ -0,0 +1,78 @@ +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.jetbrains.kotlin.android) +} + +android { + namespace = "com.arm.aichat" + compileSdk = 36 + + ndkVersion = "29.0.13113456" + + defaultConfig { + minSdk = 33 + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("consumer-rules.pro") + + ndk { + abiFilters += listOf("arm64-v8a", "x86_64") + } + externalNativeBuild { + cmake { + arguments += "-DCMAKE_BUILD_TYPE=Release" + arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG" + 
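// The GGML flags below build the CPU backend as dynamically loadable libraries, one per CPU micro-architecture variant (GGML_BACKEND_DL plus GGML_CPU_ALL_VARIANTS); the JNI init() in ai_chat.cpp loads them from the app's native library directory at runtime and ggml picks the best variant for the device. +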
arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON" + + arguments += "-DBUILD_SHARED_LIBS=ON" + arguments += "-DLLAMA_BUILD_COMMON=ON" + arguments += "-DLLAMA_OPENSSL=OFF" + + arguments += "-DGGML_NATIVE=OFF" + arguments += "-DGGML_BACKEND_DL=ON" + arguments += "-DGGML_CPU_ALL_VARIANTS=ON" + arguments += "-DGGML_LLAMAFILE=OFF" + } + } + aarMetadata { + minCompileSdk = 35 + } + } + externalNativeBuild { + cmake { + path("src/main/cpp/CMakeLists.txt") + version = "3.31.6" + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlin { + jvmToolchain(17) + + compileOptions { + targetCompatibility = JavaVersion.VERSION_17 + } + } + + packaging { + resources { + excludes += "/META-INF/{AL2.0,LGPL2.1}" + } + } + + publishing { + singleVariant("release") { + withJavadocJar() + } + } +} + +dependencies { + implementation(libs.androidx.core.ktx) + implementation(libs.androidx.datastore.preferences) + + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) +} diff --git a/llama.cpp/examples/llama.android/lib/consumer-rules.pro b/llama.cpp/examples/llama.android/lib/consumer-rules.pro new file mode 100644 index 0000000..e6eb6f5 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/consumer-rules.pro @@ -0,0 +1,8 @@ +-keep class com.arm.aichat.* { *; } +-keep class com.arm.aichat.gguf.* { *; } + +-keepclasseswithmembernames class * { + native <methods>; +} + +-keep class kotlin.Metadata { *; } diff --git a/llama.cpp/examples/llama.android/lib/proguard-rules.pro b/llama.cpp/examples/llama.android/lib/proguard-rules.pro new file mode 100644 index 0000000..f1b4245 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt b/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt new file mode 100644 index 0000000..05d6ab5 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt @@ -0,0 +1,24 @@ +package android.llama.cpp + +import androidx.test.platform.app.InstrumentationRegistry +import androidx.test.ext.junit.runners.AndroidJUnit4 + +import org.junit.Test +import org.junit.runner.RunWith + +import org.junit.Assert.* + +/** + * Instrumented test, which will execute on an Android device. + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +@RunWith(AndroidJUnit4::class) +class ExampleInstrumentedTest { + @Test + fun useAppContext() { + // Context of the app under test. 
+ val appContext = InstrumentationRegistry.getInstrumentation().targetContext + assertEquals("android.llama.cpp.test", appContext.packageName) + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml b/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml new file mode 100644 index 0000000..8bdb7e1 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android"> + +</manifest> diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt new file mode 100644 index 0000000..7862c61 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.31.6) + +project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED true) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE) + +# -------------------------------------------------------------------------- +# AI Chat library +# -------------------------------------------------------------------------- + +if(DEFINED ANDROID_ABI) + message(STATUS "Detected Android ABI: ${ANDROID_ABI}") + if(ANDROID_ABI STREQUAL "arm64-v8a") + set(GGML_SYSTEM_ARCH "ARM") + set(GGML_CPU_KLEIDIAI ON) + set(GGML_OPENMP ON) + elseif(ANDROID_ABI STREQUAL "x86_64") + set(GGML_SYSTEM_ARCH "x86") + set(GGML_CPU_KLEIDIAI OFF) + set(GGML_OPENMP OFF) + else() + message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}") + endif() +endif() + +set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../) +add_subdirectory(${LLAMA_SRC} build-llama) + +add_library(${CMAKE_PROJECT_NAME} SHARED + ai_chat.cpp) + +target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE + GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH} + GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}> + GGML_OPENMP=$<BOOL:${GGML_OPENMP}> +) + +target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE + ${LLAMA_SRC} + ${LLAMA_SRC}/common + ${LLAMA_SRC}/include + ${LLAMA_SRC}/ggml/include + ${LLAMA_SRC}/ggml/src) + +target_link_libraries(${CMAKE_PROJECT_NAME} + llama + common + android + log) diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp b/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp new file mode 100644 index 0000000..9e460ac --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp @@ -0,0 +1,565 @@ +#include <android/log.h> +#include <jni.h> +#include <iomanip> +#include <cmath> +#include <string> +#include <unistd.h> +#include <sampling.h> + +#include "logging.h" +#include "chat.h" +#include "common.h" +#include "llama.h" + +template<class T> +static std::string join(const std::vector<T> &values, const std::string &delim) { + std::ostringstream str; + for (size_t i = 0; i < values.size(); i++) { + str << values[i]; + if (i < values.size() - 1) { str << delim; } + } + return str.str(); +} + +/** + * LLama resources: context, model, batch and sampler + */ +constexpr int N_THREADS_MIN = 2; +constexpr int N_THREADS_MAX = 4; +constexpr int N_THREADS_HEADROOM = 2; + +constexpr int DEFAULT_CONTEXT_SIZE = 8192; +constexpr int OVERFLOW_HEADROOM = 4; +constexpr int BATCH_SIZE = 512; +constexpr float DEFAULT_SAMPLER_TEMP = 0.3f; + +static llama_model * 
g_model; +static llama_context * g_context; +static llama_batch g_batch; +static common_chat_templates_ptr g_chat_templates; +static common_sampler * g_sampler; + +extern "C" +JNIEXPORT void JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) { + // Set llama log handler to Android + llama_log_set(aichat_android_log_callback, nullptr); + + // Loading all CPU backend variants + const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0); + LOGi("Loading backends from %s", path_to_backend); + ggml_backend_load_all_from_path(path_to_backend); + env->ReleaseStringUTFChars(nativeLibDir, path_to_backend); + + // Initialize backends + llama_backend_init(); + LOGi("Backend initialized; log handler set."); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) { + llama_model_params model_params = llama_model_default_params(); + + const auto *model_path = env->GetStringUTFChars(jmodel_path, 0); + LOGd("%s: Loading model from: \n%s\n", __func__, model_path); + + auto *model = llama_model_load_from_file(model_path, model_params); + env->ReleaseStringUTFChars(jmodel_path, model_path); + if (!model) { + return 1; + } + g_model = model; + return 0; +} + +static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) { + if (!model) { + LOGe("%s: model cannot be null", __func__); + return nullptr; + } + + // Multi-threading setup + const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, + (int) sysconf(_SC_NPROCESSORS_ONLN) - + N_THREADS_HEADROOM)); + LOGi("%s: Using %d threads", __func__, n_threads); + + // Context parameters setup + llama_context_params ctx_params = llama_context_default_params(); + const int trained_context_size = llama_model_n_ctx_train(model); + if (n_ctx > trained_context_size) { + LOGw("%s: Model was trained with a context size of only %d! Enforcing a context size of %d anyway...", + __func__, trained_context_size, n_ctx); + } + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = BATCH_SIZE; + ctx_params.n_ubatch = BATCH_SIZE; + ctx_params.n_threads = n_threads; + ctx_params.n_threads_batch = n_threads; + auto *context = llama_init_from_model(model, ctx_params); + if (context == nullptr) { + LOGe("%s: llama_init_from_model() returned null", __func__); + } + return context; +} + +static common_sampler *new_sampler(float temp) { + common_params_sampling sparams; + sparams.temp = temp; + return common_sampler_init(g_model, sparams); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) { + auto *context = init_context(g_model); + if (!context) { return 1; } + g_context = context; + g_batch = llama_batch_init(BATCH_SIZE, 0, 1); + g_chat_templates = common_chat_templates_init(g_model, ""); + g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP); + return 0; +} + +static std::string get_backend() { + std::vector<std::string> backends; + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto *reg = ggml_backend_reg_get(i); + std::string name = ggml_backend_reg_name(reg); + if (name != "CPU") { + backends.push_back(ggml_backend_reg_name(reg)); + } + } + return backends.empty() ? "CPU" : join(backends, ","); +}
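+// For reference: get_backend() reports "CPU" when only CPU variants are registered; if a GPU backend (e.g. OpenCL) had been loaded it would be listed instead. Which names appear depends entirely on the backends shipped with the build.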
"CPU" : join(backends, ","); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) { + return env->NewStringUTF(llama_print_system_info()); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, + jint pl, jint nr) { + auto *context = init_context(g_model, pp); + if (!context) { + const auto *const err_msg = "Fail to init_context! Bench aborted."; + LOGe(err_msg); + return env->NewStringUTF(err_msg); + } + + auto pp_avg = 0.0; + auto tg_avg = 0.0; + auto pp_std = 0.0; + auto tg_std = 0.0; + + const uint32_t n_ctx = llama_n_ctx(context); + LOGi("n_ctx = %d", n_ctx); + + int i, j; + int nri; + for (nri = 0; nri < nr; nri++) { + LOGi("Benchmark prompt processing (pp = %d)", pp); + + common_batch_clear(g_batch); + + const int n_tokens = pp; + for (i = 0; i < n_tokens; i++) { + common_batch_add(g_batch, 0, i, {0}, false); + } + + g_batch.logits[g_batch.n_tokens - 1] = true; + llama_memory_clear(llama_get_memory(context), false); + + const auto t_pp_start = ggml_time_us(); + if (llama_decode(context, g_batch) != 0) { + LOGe("llama_decode() failed during prompt processing"); + } + const auto t_pp_end = ggml_time_us(); + + // bench text generation + + LOGi("Benchmark text generation (tg = %d)", tg); + + llama_memory_clear(llama_get_memory(context), false); + const auto t_tg_start = ggml_time_us(); + for (i = 0; i < tg; i++) { + common_batch_clear(g_batch); + for (j = 0; j < pl; j++) { + common_batch_add(g_batch, 0, i, {j}, true); + } + + if (llama_decode(context, g_batch) != 0) { + LOGe("llama_decode() failed during text generation"); + } + } + const auto t_tg_end = ggml_time_us(); + + llama_memory_clear(llama_get_memory(context), false); + + const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; + const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; + + const auto speed_pp = double(pp) / t_pp; + const auto speed_tg = double(pl * tg) / t_tg; + + pp_avg += speed_pp; + tg_avg += speed_tg; + + pp_std += speed_pp * speed_pp; + tg_std += speed_tg * speed_tg; + + LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); + } + + llama_free(context); + + pp_avg /= double(nr); + tg_avg /= double(nr); + + if (nr > 1) { + pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); + tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); + } else { + pp_std = 0; + tg_std = 0; + } + + char model_desc[128]; + llama_model_desc(g_model, model_desc, sizeof(model_desc)); + + const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0; + const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9; + + const auto backend = get_backend(); + std::stringstream result; + result << std::setprecision(3); + result << "| model | size | params | backend | test | t/s |\n"; + result << "| --- | --- | --- | --- | --- | --- |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; + return env->NewStringUTF(result.str().c_str()); +} + + +/** + * Completion loop's long-term states: + * - chat management + * - position tracking + */ +constexpr const char 
+/** + * Completion loop's long-term states: + * - chat management + * - position tracking + */ +constexpr const char *ROLE_SYSTEM = "system"; +constexpr const char *ROLE_USER = "user"; +constexpr const char *ROLE_ASSISTANT = "assistant"; + +static std::vector<common_chat_msg> chat_msgs; +static llama_pos system_prompt_position; +static llama_pos current_position; + +static void reset_long_term_states(const bool clear_kv_cache = true) { + chat_msgs.clear(); + system_prompt_position = 0; + current_position = 0; + + if (clear_kv_cache) + llama_memory_clear(llama_get_memory(g_context), false); +} + +/** + * TODO-hyin: implement sliding-window version as a better alternative + * + * Context shifting by discarding the older half of the tokens appended after the system prompt: + * - keep the first [system_prompt_position] tokens (the system prompt) + * - discard the older half of the (current_position - system_prompt_position) tokens that follow + * - recompute the logits in batches + */ +static void shift_context() { + const int n_discard = (current_position - system_prompt_position) / 2; + LOGi("%s: Discarding %d tokens", __func__, n_discard); + llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard); + llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard); + current_position -= n_discard; + LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); +}
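+// Worked example: with system_prompt_position = 100 and current_position = 8100, n_discard = (8100 - 100) / 2 = 4000; KV cache entries at positions [100, 4100) are removed, entries at [4100, 8100) slide down by 4000, and current_position becomes 4100, freeing roughly half of the window for new tokens.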
Shifting...", __func__); + shift_context(); + } + + // Add tokens to the batch with proper positions + for (int j = 0; j < cur_batch_size; j++) { + const llama_token token_id = tokens[i + j]; + const llama_pos position = start_pos + i + j; + const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); + common_batch_add(batch, token_id, position, {0}, want_logit); + } + + // Decode this batch + const int decode_result = llama_decode(context, batch); + if (decode_result) { + LOGe("%s: llama_decode failed w/ %d", __func__, decode_result); + return 1; + } + } + return 0; +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt( + JNIEnv *env, + jobject /*unused*/, + jstring jsystem_prompt +) { + // Reset long-term & short-term states + reset_long_term_states(); + reset_short_term_states(); + + // Obtain system prompt from JEnv + const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr); + LOGd("%s: System prompt received: \n%s", __func__, system_prompt); + std::string formatted_system_prompt(system_prompt); + env->ReleaseStringUTFChars(jsystem_prompt, system_prompt); + + // Format system prompt if applicable + const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); + if (has_chat_template) { + formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt); + } + + // Tokenize system prompt + const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, + has_chat_template, has_chat_template); + for (auto id: system_tokens) { + LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); + } + + // Handle context overflow + const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; + if ((int) system_tokens.size() > max_batch_size) { + LOGe("%s: System prompt too long for context! %d tokens, max: %d", + __func__, (int) system_tokens.size(), max_batch_size); + return 1; + } + + // Decode system tokens in batches + if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) { + LOGe("%s: llama_decode() failed!", __func__); + return 2; + } + + // Update position + system_prompt_position = current_position = (int) system_tokens.size(); + return 0; +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt( + JNIEnv *env, + jobject /*unused*/, + jstring juser_prompt, + jint n_predict +) { + // Reset short-term states + reset_short_term_states(); + + // Obtain and tokenize user prompt + const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr); + LOGd("%s: User prompt received: \n%s", __func__, user_prompt); + std::string formatted_user_prompt(user_prompt); + env->ReleaseStringUTFChars(juser_prompt, user_prompt); + + // Format user prompt if applicable + const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); + if (has_chat_template) { + formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt); + } + + // Decode formatted user prompts + auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template); + for (auto id: user_tokens) { + LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); + } + + // Ensure user prompt doesn't exceed the context size by truncating if necessary. 
+    const int user_prompt_size = (int) user_tokens.size();
+    const int max_batch_size   = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM;
+    if (user_prompt_size > max_batch_size) {
+        const int skipped_tokens = user_prompt_size - max_batch_size;
+        user_tokens.resize(max_batch_size);
+        LOGw("%s: User prompt too long! Skipped %d tokens!", __func__, skipped_tokens);
+    }
+
+    // Decode user tokens in batches
+    if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) {
+        LOGe("%s: llama_decode() failed!", __func__);
+        return 2;
+    }
+
+    // Update position (using the size after possible truncation) and mark the stop position
+    current_position += (int) user_tokens.size();
+    stop_generation_position = current_position + n_predict;
+    return 0;
+}
+
+static bool is_valid_utf8(const char *string) {
+    if (!string) { return true; }
+
+    const auto *bytes = (const unsigned char *) string;
+    int num;
+
+    while (*bytes != 0x00) {
+        if ((*bytes & 0x80) == 0x00) {
+            // U+0000 to U+007F
+            num = 1;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // U+0080 to U+07FF
+            num = 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // U+0800 to U+FFFF
+            num = 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // U+10000 to U+10FFFF
+            num = 4;
+        } else {
+            return false;
+        }
+
+        bytes += 1;
+        for (int i = 1; i < num; ++i) {
+            if ((*bytes & 0xC0) != 0x80) {
+                return false;
+            }
+            bytes += 1;
+        }
+    }
+    return true;
+}
+
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken(
+        JNIEnv *env,
+        jobject /*unused*/
+) {
+    // Infinite text generation via context shifting
+    if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) {
+        LOGw("%s: Context full! Shifting...", __func__);
+        shift_context();
+    }
+
+    // Stop if reaching the marked position
+    if (current_position >= stop_generation_position) {
+        LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position);
+        return nullptr;
+    }
+
+    // Sample next token
+    const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1);
+    common_sampler_accept(g_sampler, new_token_id, true);
+
+    // Populate the batch with the new token, then decode
+    common_batch_clear(g_batch);
+    common_batch_add(g_batch, new_token_id, current_position, {0}, true);
+    if (llama_decode(g_context, g_batch) != 0) {
+        LOGe("%s: llama_decode() failed for generated token", __func__);
+        return nullptr;
+    }
+
+    // Update position
+    current_position++;
+
+    // Stop if next token is EOG
+    if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) {
+        LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id);
+        chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str());
+        return nullptr;
+    }
+
+    // If not EOG, convert to text
+    auto new_token_chars = common_token_to_piece(g_context, new_token_id);
+    cached_token_chars += new_token_chars;
+
+    // Create and return a valid UTF-8 Java string
+    jstring result = nullptr;
+    if (is_valid_utf8(cached_token_chars.c_str())) {
+        result = env->NewStringUTF(cached_token_chars.c_str());
+        LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str());
+
+        assistant_ss << cached_token_chars;
+        cached_token_chars.clear();
+    } else {
+        LOGv("id: %d,\tappend to cache", new_token_id);
+        result = env->NewStringUTF("");
+    }
+    return result;
+}
+
+
+extern "C"
+JNIEXPORT void JNICALL
+Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) {
+    // Reset long-term & short-term states
+    reset_long_term_states();
+    reset_short_term_states();
+
+    // Free up resources
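+    // Free order mirrors acquisition in reverse: sampler and chat templates first,
+    // then the batch, then the context, and finally the model that backs them all.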
common_sampler_free(g_sampler); + g_chat_templates.reset(); + llama_batch_free(g_batch); + llama_free(g_context); + llama_model_free(g_model); +} + +extern "C" +JNIEXPORT void JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) { + llama_backend_free(); +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h b/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h new file mode 100644 index 0000000..2e768d2 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h @@ -0,0 +1,61 @@ +// +// Created by Han Yin on 10/31/25. +// + +#ifndef AICHAT_LOGGING_H +#define AICHAT_LOGGING_H + +#endif //AICHAT_LOGGING_H + +#pragma once +#include <android/log.h> + +#ifndef LOG_TAG +#define LOG_TAG "ai-chat" +#endif + +#ifndef LOG_MIN_LEVEL +#if defined(NDEBUG) +#define LOG_MIN_LEVEL ANDROID_LOG_INFO +#else +#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE +#endif +#endif + +static inline int ai_should_log(int prio) { + return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL); +} + +#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE +#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0) +#else +#define LOGv(...) ((void)0) +#endif + +#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG +#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0) +#else +#define LOGd(...) ((void)0) +#endif + +#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0) +#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0) +#define LOGe(...) do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0) + +static inline int android_log_prio_from_ggml(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR; + case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN; + case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO; + case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG; + default: return ANDROID_LOG_DEFAULT; + } +} + +static inline void aichat_android_log_callback(enum ggml_log_level level, + const char* text, + void* /*user*/) { + const int prio = android_log_prio_from_ggml(level); + if (!ai_should_log(prio)) return; + __android_log_write(prio, LOG_TAG, text); +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt new file mode 100644 index 0000000..b72a24e --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt @@ -0,0 +1,14 @@ +package com.arm.aichat + +import android.content.Context +import com.arm.aichat.internal.InferenceEngineImpl + +/** + * Main entry point for Arm's AI Chat library. + */ +object AiChat { + /** + * Get the inference engine single instance. 
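+     *
+     * A minimal usage sketch (the model path and coroutine scope are illustrative):
+     * ```
+     * scope.launch {
+     *     val engine = AiChat.getInferenceEngine(context)
+     *     engine.loadModel("/path/to/model.gguf")
+     *     engine.sendUserPrompt("Hello!").collect { print(it) }
+     * }
+     * ```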
+ */ + fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context) +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt new file mode 100644 index 0000000..26c1668 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt @@ -0,0 +1,89 @@ +package com.arm.aichat + +import com.arm.aichat.InferenceEngine.State +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.StateFlow + +/** + * Interface defining the core LLM inference operations. + */ +interface InferenceEngine { + /** + * Current state of the inference engine + */ + val state: StateFlow<State> + + /** + * Load a model from the given path. + * + * @throws UnsupportedArchitectureException if model architecture not supported + */ + suspend fun loadModel(pathToModel: String) + + /** + * Sends a system prompt to the loaded model + */ + suspend fun setSystemPrompt(systemPrompt: String) + + /** + * Sends a user prompt to the loaded model and returns a Flow of generated tokens. + */ + fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> + + /** + * Runs a benchmark with the specified parameters. + */ + suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String + + /** + * Unloads the currently loaded model. + */ + fun cleanUp() + + /** + * Cleans up resources when the engine is no longer needed. + */ + fun destroy() + + /** + * States of the inference engine + */ + sealed class State { + object Uninitialized : State() + object Initializing : State() + object Initialized : State() + + object LoadingModel : State() + object UnloadingModel : State() + object ModelReady : State() + + object Benchmarking : State() + object ProcessingSystemPrompt : State() + object ProcessingUserPrompt : State() + + object Generating : State() + + data class Error(val exception: Exception) : State() + } + + companion object { + const val DEFAULT_PREDICT_LENGTH = 1024 + } +} + +val State.isUninterruptible + get() = this is State.Initializing || + this is State.LoadingModel || + this is State.UnloadingModel || + this is State.Benchmarking || + this is State.ProcessingSystemPrompt || + this is State.ProcessingUserPrompt + +val State.isModelLoaded: Boolean + get() = this is State.ModelReady || + this is State.Benchmarking || + this is State.ProcessingSystemPrompt || + this is State.ProcessingUserPrompt || + this is State.Generating + +class UnsupportedArchitectureException : Exception() diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt new file mode 100644 index 0000000..2f15eef --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt @@ -0,0 +1,61 @@ +package com.arm.aichat.gguf + +import kotlin.collections.get + + +/** + * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`). + * The `label` matches what llama‑cli prints. 
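+ * For example, `fromCode(15)` resolves to [MOSTLY_Q4_K_M] ("Q4_K - Medium"), and any
+ * unrecognised code falls back to [UNKNOWN].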
+ */ +enum class FileType(val code: Int, val label: String) { + ALL_F32(0, "all F32"), + MOSTLY_F16(1, "F16"), + MOSTLY_Q4_0(2, "Q4_0"), + MOSTLY_Q4_1(3, "Q4_1"), + // 4 removed + MOSTLY_Q8_0(7, "Q8_0"), + MOSTLY_Q5_0(8, "Q5_0"), + MOSTLY_Q5_1(9, "Q5_1"), + + /* K‑quants ------------------------------------------------------------ */ + MOSTLY_Q2_K (10, "Q2_K - Medium"), + MOSTLY_Q3_K_S (11, "Q3_K - Small"), + MOSTLY_Q3_K_M (12, "Q3_K - Medium"), + MOSTLY_Q3_K_L (13, "Q3_K - Large"), + MOSTLY_Q4_K_S (14, "Q4_K - Small"), + MOSTLY_Q4_K_M (15, "Q4_K - Medium"), + MOSTLY_Q5_K_S (16, "Q5_K - Small"), + MOSTLY_Q5_K_M (17, "Q5_K - Medium"), + MOSTLY_Q6_K (18, "Q6_K"), + + /* IQ quants ----------------------------------------------------------- */ + MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"), + MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"), + MOSTLY_Q2_K_S (21, "Q2_K - Small"), + MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"), + MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"), + MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"), + MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"), + MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"), + MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"), + MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"), + MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"), + MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"), + MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"), + + /* BF16 & Ternary ------------------------------------------------------ */ + MOSTLY_BF16 (32, "BF16"), + MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"), + MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"), + + /* Special flag -------------------------------------------------------- */ + GUESSED(1024, "(guessed)"), + + UNKNOWN(-1, "unknown"); + + companion object { + private val map = entries.associateBy(FileType::code) + + fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt new file mode 100644 index 0000000..5e1971a --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt @@ -0,0 +1,132 @@ +package com.arm.aichat.gguf + +import java.io.IOException + + +/** + * Structured metadata of GGUF + */ +data class GgufMetadata( + // Basic file info + val version: GgufVersion, + val tensorCount: Long, + val kvCount: Long, + + // General info + val basic: BasicInfo, + val author: AuthorInfo? = null, + val additional: AdditionalInfo? = null, + val architecture: ArchitectureInfo? = null, + val baseModels: List<BaseModelInfo>? = null, + val tokenizer: TokenizerInfo? = null, + + // Derivative info + val dimensions: DimensionsInfo? = null, + val attention: AttentionInfo? = null, + val rope: RopeInfo? = null, + val experts: ExpertsInfo? = null +) { + enum class GgufVersion(val code: Int, val label: String) { + /** First public draft; little‑endian only, no alignment key. */ + LEGACY_V1(1, "Legacy v1"), + + /** Added split‑file support and some extra metadata keys. */ + EXTENDED_V2(2, "Extended v2"), + + /** Current spec: endian‑aware, mandatory alignment, fully validated. */ + VALIDATED_V3(3, "Validated v3"); + + companion object { + fun fromCode(code: Int): GgufVersion = + entries.firstOrNull { it.code == code } + ?: throw IOException("Unknown GGUF version code $code") + } + + override fun toString(): String = "$label (code=$code)" + } + + data class BasicInfo( + val uuid: String? = null, + val name: String? = null, + val nameLabel: String? = null, + val sizeLabel: String? 
= null, // Size label like "7B" + ) + + data class AuthorInfo( + val organization: String? = null, + val author: String? = null, + val doi: String? = null, + val url: String? = null, + val repoUrl: String? = null, + val license: String? = null, + val licenseLink: String? = null, + ) + + data class AdditionalInfo( + val type: String? = null, + val description: String? = null, + val tags: List<String>? = null, + val languages: List<String>? = null, + ) + + data class ArchitectureInfo( + val architecture: String? = null, + val fileType: Int? = null, + val vocabSize: Int? = null, + val finetune: String? = null, + val quantizationVersion: Int? = null, + ) + + data class BaseModelInfo( + val name: String? = null, + val author: String? = null, + val version: String? = null, + val organization: String? = null, + val url: String? = null, + val doi: String? = null, + val uuid: String? = null, + val repoUrl: String? = null, + ) + + data class TokenizerInfo( + val model: String? = null, + val bosTokenId: Int? = null, + val eosTokenId: Int? = null, + val unknownTokenId: Int? = null, + val paddingTokenId: Int? = null, + val addBosToken: Boolean? = null, + val addEosToken: Boolean? = null, + val chatTemplate: String? = null, + ) + + data class DimensionsInfo( + val contextLength: Int? = null, + val embeddingSize: Int? = null, + val blockCount: Int? = null, + val feedForwardSize: Int? = null, + ) + + data class AttentionInfo( + val headCount: Int? = null, + val headCountKv: Int? = null, + val keyLength: Int? = null, + val valueLength: Int? = null, + val layerNormEpsilon: Float? = null, + val layerNormRmsEpsilon: Float? = null, + ) + + data class RopeInfo( + val frequencyBase: Float? = null, + val dimensionCount: Int? = null, + val scalingType: String? = null, + val scalingFactor: Float? = null, + val attnFactor: Float? = null, + val originalContextLength: Int? = null, + val finetuned: Boolean? = null, + ) + + data class ExpertsInfo( + val count: Int? = null, + val usedCount: Int? = null, + ) +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt new file mode 100644 index 0000000..264a6c0 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt @@ -0,0 +1,77 @@ +package com.arm.aichat.gguf + +import android.content.Context +import android.net.Uri +import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl +import java.io.File +import java.io.IOException +import java.io.InputStream + +/** + * Interface for reading GGUF metadata from model files. + * Use `GgufMetadataReader.create()` to get an instance. + */ +interface GgufMetadataReader { + /** + * Reads the magic number from the specified file path. + * + * @param file Java File to the GGUF file with absolute path + * @return true if file is valid GGUF, otherwise false + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun ensureSourceFileFormat(file: File): Boolean + + /** + * Reads the magic number from the specified file path. 
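+     * (Same contract as the [File] overload; the stream is resolved through the ContentResolver.)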
+ * + * @param context Context for obtaining [android.content.ContentProvider] + * @param uri Uri to the GGUF file provided by [android.content.ContentProvider] + * @return true if file is valid GGUF, otherwise false + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean + + /** + * Reads and parses GGUF metadata from the specified file path. + * + * @param input the [InputStream] obtained from a readable file or content + * @return Structured metadata extracted from the file + * @throws IOException if file is damaged or cannot be read + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun readStructuredMetadata(input: InputStream): GgufMetadata + + companion object { + private val DEFAULT_SKIP_KEYS = setOf( + "tokenizer.chat_template", + "tokenizer.ggml.scores", + "tokenizer.ggml.tokens", + "tokenizer.ggml.token_type" + ) + + /** + * Creates a default GgufMetadataReader instance + */ + fun create(): GgufMetadataReader = GgufMetadataReaderImpl( + skipKeys = DEFAULT_SKIP_KEYS, + arraySummariseThreshold = 1_000 + ) + + /** + * Creates a GgufMetadataReader with custom configuration + * + * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map) + * @param arraySummariseThreshold If ≥0, arrays longer get summarised, not materialised; + * If -1, never summarise. + */ + fun create( + skipKeys: Set<String> = DEFAULT_SKIP_KEYS, + arraySummariseThreshold: Int = 1_000 + ): GgufMetadataReader = GgufMetadataReaderImpl( + skipKeys = skipKeys, + arraySummariseThreshold = arraySummariseThreshold + ) + } +} + +class InvalidFileFormatException : IOException() diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt new file mode 100644 index 0000000..a293796 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt @@ -0,0 +1,324 @@ +package com.arm.aichat.internal + +import android.content.Context +import android.util.Log +import com.arm.aichat.InferenceEngine +import com.arm.aichat.UnsupportedArchitectureException +import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance +import dalvik.annotation.optimization.FastNative +import kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import java.io.File +import java.io.IOException + +/** + * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models. + * + * This class implements a singleton pattern for managing the lifecycle of a single LLM instance. + * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety + * with the underlying C++ native code. + * + * The typical usage flow is: + * 1. Get instance via [getInstance] + * 2. 
Load a model with [loadModel]
+ * 3. Send prompts with [sendUserPrompt]
+ * 4. Generate responses as token streams
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
+ *
+ * State transitions are managed automatically and validated at each operation.
+ *
+ * @see ai_chat.cpp for the native implementation details
+ */
+internal class InferenceEngineImpl private constructor(
+    private val nativeLibDir: String
+) : InferenceEngine {
+
+    companion object {
+        private val TAG = InferenceEngineImpl::class.java.simpleName
+
+        @Volatile
+        private var instance: InferenceEngine? = null
+
+        /**
+         * Create or obtain [InferenceEngineImpl]'s single instance.
+         *
+         * @param context Context for obtaining the native library directory
+         * @throws IllegalArgumentException if the native library path is invalid
+         * @throws UnsatisfiedLinkError if the library failed to load
+         */
+        internal fun getInstance(context: Context) =
+            instance ?: synchronized(this) {
+                val nativeLibDir = context.applicationInfo.nativeLibraryDir
+                require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
+
+                try {
+                    Log.i(TAG, "Instantiating InferenceEngineImpl...")
+                    InferenceEngineImpl(nativeLibDir).also { instance = it }
+                } catch (e: UnsatisfiedLinkError) {
+                    Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
+                    throw e
+                }
+            }
+    }
+
+    /**
+     * JNI methods
+     * @see ai_chat.cpp
+     */
+    @FastNative
+    private external fun init(nativeLibDir: String)
+
+    @FastNative
+    private external fun load(modelPath: String): Int
+
+    @FastNative
+    private external fun prepare(): Int
+
+    @FastNative
+    private external fun systemInfo(): String
+
+    @FastNative
+    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+    @FastNative
+    private external fun processSystemPrompt(systemPrompt: String): Int
+
+    @FastNative
+    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+
+    @FastNative
+    private external fun generateNextToken(): String?
+
+    @FastNative
+    private external fun unload()
+
+    @FastNative
+    private external fun shutdown()
+
+    private val _state =
+        MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
+    override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()
+
+    private var _readyForSystemPrompt = false
+    @Volatile
+    private var _cancelGeneration = false
+
+    /**
+     * Single-threaded coroutine dispatcher & scope for llama.cpp's asynchronous operations
+     */
+    @OptIn(ExperimentalCoroutinesApi::class)
+    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+    init {
+        llamaScope.launch {
+            try {
+                check(_state.value is InferenceEngine.State.Uninitialized) {
+                    "Cannot load native library in ${_state.value.javaClass.simpleName}!"
+                }
+                _state.value = InferenceEngine.State.Initializing
+                Log.i(TAG, "Loading native library...")
+                System.loadLibrary("ai-chat")
+                init(nativeLibDir)
+                _state.value = InferenceEngine.State.Initialized
+                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+
+            } catch (e: Exception) {
+                Log.e(TAG, "Failed to load native library", e)
+                throw e
+            }
+        }
+    }
+
+    /**
+     * Load the LLM
+     */
+    override suspend fun loadModel(pathToModel: String) =
+        withContext(llamaDispatcher) {
+            check(_state.value is InferenceEngine.State.Initialized) {
+                "Cannot load model in ${_state.value.javaClass.simpleName}!"
+            }
+
+            try {
+                Log.i(TAG, "Checking access to model file... 
\n$pathToModel") + File(pathToModel).let { + require(it.exists()) { "File not found" } + require(it.isFile) { "Not a valid file" } + require(it.canRead()) { "Cannot read file" } + } + + Log.i(TAG, "Loading model... \n$pathToModel") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.LoadingModel + load(pathToModel).let { + // TODO-han.yin: find a better way to pass other error codes + if (it != 0) throw UnsupportedArchitectureException() + } + prepare().let { + if (it != 0) throw IOException("Failed to prepare resources") + } + Log.i(TAG, "Model loaded!") + _readyForSystemPrompt = true + + _cancelGeneration = false + _state.value = InferenceEngine.State.ModelReady + } catch (e: Exception) { + Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e) + _state.value = InferenceEngine.State.Error(e) + throw e + } + } + + /** + * Process the plain text system prompt + * + * TODO-han.yin: return error code if system prompt not correct processed? + */ + override suspend fun setSystemPrompt(prompt: String) = + withContext(llamaDispatcher) { + require(prompt.isNotBlank()) { "Cannot process empty system prompt!" } + check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" } + check(_state.value is InferenceEngine.State.ModelReady) { + "Cannot process system prompt in ${_state.value.javaClass.simpleName}!" + } + + Log.i(TAG, "Sending system prompt...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.ProcessingSystemPrompt + processSystemPrompt(prompt).let { result -> + if (result != 0) { + RuntimeException("Failed to process system prompt: $result").also { + _state.value = InferenceEngine.State.Error(it) + throw it + } + } + } + Log.i(TAG, "System prompt processed! Awaiting user prompt...") + _state.value = InferenceEngine.State.ModelReady + } + + /** + * Send plain text user prompt to LLM, which starts generating tokens in a [Flow] + */ + override fun sendUserPrompt( + message: String, + predictLength: Int, + ): Flow<String> = flow { + require(message.isNotEmpty()) { "User prompt discarded due to being empty!" } + check(_state.value is InferenceEngine.State.ModelReady) { + "User prompt discarded due to: ${_state.value.javaClass.simpleName}" + } + + try { + Log.i(TAG, "Sending user prompt...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.ProcessingUserPrompt + + processUserPrompt(message, predictLength).let { result -> + if (result != 0) { + Log.e(TAG, "Failed to process user prompt: $result") + return@flow + } + } + + Log.i(TAG, "User prompt processed. Generating assistant prompt...") + _state.value = InferenceEngine.State.Generating + while (!_cancelGeneration) { + generateNextToken()?.let { utf8token -> + if (utf8token.isNotEmpty()) emit(utf8token) + } ?: break + } + if (_cancelGeneration) { + Log.i(TAG, "Assistant generation aborted per requested.") + } else { + Log.i(TAG, "Assistant generation complete. 
Awaiting user prompt...") + } + _state.value = InferenceEngine.State.ModelReady + } catch (e: CancellationException) { + Log.i(TAG, "Assistant generation's flow collection cancelled.") + _state.value = InferenceEngine.State.ModelReady + throw e + } catch (e: Exception) { + Log.e(TAG, "Error during generation!", e) + _state.value = InferenceEngine.State.Error(e) + throw e + } + }.flowOn(llamaDispatcher) + + /** + * Benchmark the model + */ + override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String = + withContext(llamaDispatcher) { + check(_state.value is InferenceEngine.State.ModelReady) { + "Benchmark request discarded due to: $state" + } + Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)") + _readyForSystemPrompt = false // Just to be safe + _state.value = InferenceEngine.State.Benchmarking + benchModel(pp, tg, pl, nr).also { + _state.value = InferenceEngine.State.ModelReady + } + } + + /** + * Unloads the model and frees resources, or reset error states + */ + override fun cleanUp() { + _cancelGeneration = true + runBlocking(llamaDispatcher) { + when (val state = _state.value) { + is InferenceEngine.State.ModelReady -> { + Log.i(TAG, "Unloading model and free resources...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.UnloadingModel + + unload() + + _state.value = InferenceEngine.State.Initialized + Log.i(TAG, "Model unloaded!") + Unit + } + + is InferenceEngine.State.Error -> { + Log.i(TAG, "Resetting error states...") + _state.value = InferenceEngine.State.Initialized + Log.i(TAG, "States reset!") + Unit + } + + else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}") + } + } + } + + /** + * Cancel all ongoing coroutines and free GGML backends + */ + override fun destroy() { + _cancelGeneration = true + runBlocking(llamaDispatcher) { + _readyForSystemPrompt = false + when(_state.value) { + is InferenceEngine.State.Uninitialized -> {} + is InferenceEngine.State.Initialized -> shutdown() + else -> { unload(); shutdown() } + } + } + llamaScope.cancel() + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt new file mode 100644 index 0000000..bf250ac --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt @@ -0,0 +1,590 @@ +package com.arm.aichat.internal.gguf + +import android.content.Context +import android.net.Uri +import com.arm.aichat.gguf.GgufMetadata +import com.arm.aichat.gguf.GgufMetadataReader +import com.arm.aichat.gguf.InvalidFileFormatException +import java.io.File +import java.io.IOException +import java.io.InputStream + + +/** + * Utility class to read GGUF model files and extract metadata key-value pairs. + * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data. + */ +internal class GgufMetadataReaderImpl( + private val skipKeys: Set<String>, + private val arraySummariseThreshold: Int, +) : GgufMetadataReader { + companion object { + private const val ARCH_LLAMA = "llama" + } + + /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). 
+     */
+    enum class MetadataType(val code: Int) {
+        UINT8(0), INT8(1), UINT16(2), INT16(3),
+        UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
+        STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
+        companion object {
+            private val codeMap = entries.associateBy(MetadataType::code)
+            fun fromCode(code: Int): MetadataType = codeMap[code]
+                ?: throw IOException("Unknown metadata value type code: $code")
+        }
+    }
+
+    /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
+    sealed class MetadataValue {
+        data class UInt8(val value: UByte) : MetadataValue()       // 0: 8-bit unsigned int
+        data class Int8(val value: Byte) : MetadataValue()         // 1: 8-bit signed int
+        data class UInt16(val value: UShort) : MetadataValue()     // 2: 16-bit unsigned int (little-endian)
+        data class Int16(val value: Short) : MetadataValue()       // 3: 16-bit signed int (little-endian)
+        data class UInt32(val value: UInt) : MetadataValue()       // 4: 32-bit unsigned int (little-endian)
+        data class Int32(val value: Int) : MetadataValue()         // 5: 32-bit signed int (little-endian)
+        data class Float32(val value: Float) : MetadataValue()     // 6: 32-bit IEEE754 float
+        data class Bool(val value: Boolean) : MetadataValue()      // 7: Boolean (1-byte, 0=false, 1=true)
+        data class StringVal(val value: String) : MetadataValue()  // 8: UTF-8 string (length-prefixed)
+        data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
+        data class UInt64(val value: ULong) : MetadataValue()      // 10: 64-bit unsigned int (little-endian)
+        data class Int64(val value: Long) : MetadataValue()        // 11: 64-bit signed int (little-endian)
+        data class Float64(val value: Double) : MetadataValue()    // 12: 64-bit IEEE754 double
+    }
+
+    /* Convert MetadataValue to plain Kotlin primitives */
+    private fun MetadataValue.toPrimitive(): Any = when (this) {
+        is MetadataValue.UInt8 -> value
+        is MetadataValue.Int8 -> value
+        is MetadataValue.UInt16 -> value
+        is MetadataValue.Int16 -> value
+        is MetadataValue.UInt32 -> value
+        is MetadataValue.Int32 -> value
+        is MetadataValue.Float32 -> value
+        is MetadataValue.Bool -> value
+        is MetadataValue.StringVal -> value
+        is MetadataValue.UInt64 -> value
+        is MetadataValue.Int64 -> value
+        is MetadataValue.Float64 -> value
+        is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
+    }
+
+    /**
+     * Reads the magic number from the given file.
+     *
+     * @param file Java [File] pointing at the GGUF file (absolute path)
+     * @return true if the file is valid GGUF, otherwise false
+     */
+    override suspend fun ensureSourceFileFormat(file: File): Boolean =
+        file.inputStream().buffered().use { ensureMagic(it) }
+
+    /**
+     * Reads the magic number from the given content [Uri].
+     *
+     * @param context Context for obtaining the ContentResolver
+     * @param uri Uri to the GGUF file provided by a ContentProvider
+     * @return true if the file is valid GGUF, otherwise false
+     */
+    override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
+        context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
+
+    /** Reads the 4-byte magic; returns false on a mismatch and throws on a truncated read.
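+     * The magic is the byte sequence 0x47 0x47 0x55 0x46 ("GGUF") at the start of the file.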
+     */
+    private fun ensureMagic(input: InputStream): Boolean =
+        ByteArray(4).let {
+            if (input.read(it) != 4) throw IOException("Not a valid file!")
+            it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
+        }
+
+    /**
+     * High-level entry point: parses a `.gguf` stream and returns the fully
+     * populated [GgufMetadata] tree.
+     *
+     * Steps performed internally:
+     * 1. Reads and validates the 8-byte header (`"GGUF"` magic + version).
+     * 2. Streams through the key-value section, skipping large blobs if the key
+     *    appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
+     * 3. Converts the resulting raw map into strongly-typed sub-structures
+     *    (basic info, tokenizer, rope, etc.).
+     *
+     * The method is STREAMING-ONLY: tensors are never mapped or loaded into
+     * memory, so even multi-GB model files can be processed in < 50 ms.
+     *
+     * @param input Open [InputStream] positioned at the start of a `.gguf` file.
+     * @return A [GgufMetadata] instance containing all recognised metadata.
+     * @throws IOException if the file is not GGUF, the version is unsupported,
+     *         or the metadata block is truncated / corrupt.
+     */
+    override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
+        // ── 1. header ──────────────────────────────────────────────────────────
+        // throws on mismatch
+        val version     = ensureMagicAndVersion(input)
+        val tensorCount = readLittleLong(input)
+        val kvCount     = readLittleLong(input)
+
+        // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
+        val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
+
+        // ── 3. build structured object ─────────────────────────────────────────
+        return buildStructured(meta, version, tensorCount, kvCount)
+    }
+
+    /** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF". */
+    private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
+        if (!ensureMagic(input)) throw InvalidFileFormatException()
+        return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
+    }
+
+    /**
+     * Read an unsigned 32-bit little-endian integer.
+     *
+     * @throws IOException if fewer than four bytes are available.
+     */
+    private fun readLEUInt32(input: InputStream): Int {
+        val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
+        if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
+        return (b3 and 0xFF shl 24) or
+               (b2 and 0xFF shl 16) or
+               (b1 and 0xFF shl 8) or
+               (b0 and 0xFF)
+    }
+
+    /**
+     * Low-level helper that reads the entire "key-value" section from the current
+     * stream position.
+     *
+     * @param input Open stream positioned JUST AFTER the header.
+     * @param kvCnt Number of key-value pairs (taken from the header).
+     * @return Map with one [MetadataValue] for every key that is NOT skipped.
+     *
+     * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
+     * [skipValue] or [parseValue] accordingly.
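+     *
+     * Wire layout per pair: key string (u64 little-endian byte length + UTF-8 bytes),
+     * then a u32 type code, then the value bytes for that type.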
+ */ + private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> = + mutableMapOf<String, MetadataValue>().apply { + repeat(kvCnt.toInt()) { + val key = readString(input) + val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + if (key in skipKeys) { + skipValue(input, valueT) + } else { + this[key] = parseValue(input, valueT) + } + } + } + + /** + * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed + * [GgufMetadata] tree used by the rest of the app. + * + * Only the keys listed in the spec are copied into dedicated data classes; + * everything else is preserved in `GgufMetadata.allMetadata`. + * + * @param m Raw key/value map. + * @param version GGUF file‑format version (enum). + * @param tensorCnt Number of tensors (from the header). + * @param kvCnt Total metadata pair count (from the header). + */ + private fun buildStructured( + m: Map<String, MetadataValue>, + version: GgufMetadata.GgufVersion, + tensorCnt: Long, + kvCnt: Long + ): GgufMetadata { + // ---------- helpers ---------- + fun String.str() = (m[this] as? MetadataValue.StringVal)?.value + fun String.bool() = (m[this] as? MetadataValue.Bool)?.value + fun String.i32() = (m[this] as? MetadataValue.Int32)?.value + fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt() + fun String.f32() = (m[this] as? MetadataValue.Float32)?.value + fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat() + fun String.strList(): List<String>? = + (m[this] as? MetadataValue.ArrayVal) + ?.elements + ?.mapNotNull { (it as? MetadataValue.StringVal)?.value } + + val arch = "general.architecture".str() ?: ARCH_LLAMA + + // -------------- populate sections ---------------- + val basic = GgufMetadata.BasicInfo( + uuid = "general.uuid".str(), + name = "general.basename".str(), + nameLabel = "general.name".str(), + sizeLabel = "general.size_label".str() + ) + + val author = GgufMetadata.AuthorInfo( + organization = "general.organization".str(), + author = "general.author".str(), + doi = "general.doi".str(), + url = "general.url".str(), + repoUrl = "general.repo_url".str(), + license = "general.license".str(), + licenseLink = "general.license.link".str() + ).takeUnless { + organization == null && author == null && doi == null && + url == null && repoUrl == null && license == null && licenseLink == null + } + + val additional = GgufMetadata.AdditionalInfo( + type = "general.type".str(), + description = "general.description".str(), + tags = "general.tags".strList(), + languages = "general.languages".strList() + ).takeUnless { + type == null && description == null && tags == null && languages == null + } + + val architectureInfo = GgufMetadata.ArchitectureInfo( + architecture = arch, + fileType = "general.file_type".u32(), + vocabSize = "$arch.vocab_size".u32(), + finetune = "general.finetune".str(), + quantizationVersion = "general.quantization_version".u32() + ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null } + + val baseModels = buildList { + val n = "general.base_model.count".u32() ?: 0 + for (i in 0 until n) { + fun k(s: String) = "general.base_model.$i.$s" + add( + GgufMetadata.BaseModelInfo( + name = k("name").str(), + author = k("author").str(), + version = k("version").str(), + organization = k("organization").str(), + url = k("url").str(), + doi = k("doi").str(), + uuid = k("uuid").str(), + repoUrl = k("repo_url").str(), + ) + ) + } + }.takeIf { it.isNotEmpty() } + + val tokenizer = 
GgufMetadata.TokenizerInfo( + model = "tokenizer.ggml.model".str(), + bosTokenId = "tokenizer.ggml.bos_token_id".u32(), + eosTokenId = "tokenizer.ggml.eos_token_id".u32(), + unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(), + paddingTokenId = "tokenizer.ggml.padding_token_id".u32(), + addBosToken = "tokenizer.ggml.add_bos_token".bool(), + addEosToken = "tokenizer.ggml.add_eos_token".bool(), + chatTemplate = "tokenizer.chat_template".str() + ).takeUnless { model == null && bosTokenId == null && eosTokenId == null && + unknownTokenId == null && paddingTokenId == null && + addBosToken == null && addEosToken == null && chatTemplate == null + } + + val dimensions = GgufMetadata.DimensionsInfo( + contextLength = "$arch.context_length".u32(), + embeddingSize = "$arch.embedding_length".u32(), + blockCount = "$arch.block_count".u32(), + feedForwardSize = "$arch.feed_forward_length".u32() + ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null } + + val attention = GgufMetadata.AttentionInfo( + headCount = "$arch.attention.head_count".u32(), + headCountKv = "$arch.attention.head_count_kv".u32(), + keyLength = "$arch.attention.key_length".u32(), + valueLength = "$arch.attention.value_length".u32(), + layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(), + layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(), + ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null && + layerNormEpsilon == null && layerNormRmsEpsilon == null + } + + val rope = GgufMetadata.RopeInfo( + frequencyBase = "$arch.rope.freq_base".f32(), + dimensionCount = "$arch.rope.dimension_count".u32(), + scalingType = "$arch.rope.scaling.type".str(), + scalingFactor = "$arch.rope.scaling.factor".f32(), + attnFactor = "$arch.rope.scaling.attn_factor".f32(), + originalContextLength = "$arch.rope.scaling.original_context_length".u32(), + finetuned = "$arch.rope.scaling.finetuned".bool() + ).takeUnless { frequencyBase == null && dimensionCount == null && + scalingType == null && scalingFactor == null && attnFactor == null && + originalContextLength == null && finetuned == null + } + + val experts = GgufMetadata.ExpertsInfo( + count = "$arch.expert_count".u32(), + usedCount = "$arch.expert_used_count".u32() + ).takeUnless { count == null && usedCount == null } + + return GgufMetadata( + version = version, + tensorCount = tensorCnt, + kvCount = kvCnt, + basic = basic, + author = author, + additional = additional, + architecture = architectureInfo, + baseModels = baseModels, + tokenizer = tokenizer, + dimensions = dimensions, + attention = attention, + rope = rope, + experts = experts + ) + } + + /** + * Recursively parses a metadata value of the given type from the input stream. + * @param input The input stream positioned at the start of the value. + * @param type The metadata value type to parse. 
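+     * @return The decoded [MetadataValue]; [MetadataType.ARRAY] recurses per element and
+     *         may be summarised into a string when [arraySummariseThreshold] is exceeded.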
+ */ + private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) { + MetadataType.UINT8 -> { + // 1-byte unsigned integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.") + MetadataValue.UInt8(byteVal.toUByte()) + } + MetadataType.INT8 -> { + // 1-byte signed integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.") + MetadataValue.Int8(byteVal.toByte()) + } + MetadataType.UINT16 -> { + // 2-byte unsigned integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.") + // Combine two bytes (little-endian) into an unsigned 16-bit value + val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.UInt16(u16.toUShort()) + } + MetadataType.INT16 -> { + // 2-byte signed integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.") + // Combine to 16-bit and interpret as signed + val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.Int16(i16.toShort()) + } + MetadataType.UINT32 -> { + // 4-byte unsigned integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.") + // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt + val u32 = (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.UInt32(u32.toUInt()) + } + MetadataType.INT32 -> { + // 4-byte signed integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.") + // Combine four bytes into a 32-bit signed int + val i32 = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + MetadataValue.Int32(i32) + } + MetadataType.FLOAT32 -> { + // 4-byte IEEE 754 float (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.") + // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float + val bits = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + val floatVal = Float.fromBits(bits) + MetadataValue.Float32(floatVal) + } + MetadataType.BOOL -> { + // 1-byte boolean (0 = false, 1 = true) + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.") + if (byteVal != 0 && byteVal != 1) { + throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).") + } + MetadataValue.Bool(byteVal != 0) + } + MetadataType.STRING -> { + // UTF-8 string (length-prefixed with 8-byte length) + val str = readString(input) + MetadataValue.StringVal(str) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + val count = len.toInt() + + if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) { + // fast‑forward without allocation + repeat(count) { skipValue(input, elemType) } + MetadataValue.StringVal("Array($elemType, 
$count items) /* summarised */") + } else { + val list = ArrayList<MetadataValue>(count) + repeat(count) { list += parseValue(input, elemType) } + MetadataValue.ArrayVal(elemType, list) + } + } + MetadataType.UINT64 -> { + // 8-byte unsigned integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.") + // Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range. + val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or + (bytes[6].toULong() and 0xFFuL shl 48) or + (bytes[5].toULong() and 0xFFuL shl 40) or + (bytes[4].toULong() and 0xFFuL shl 32) or + (bytes[3].toULong() and 0xFFuL shl 24) or + (bytes[2].toULong() and 0xFFuL shl 16) or + (bytes[1].toULong() and 0xFFuL shl 8) or + (bytes[0].toULong() and 0xFFuL) + MetadataValue.UInt64(u64) + } + MetadataType.INT64 -> { + // 8-byte signed integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.") + // Combine 8 bytes into a signed 64-bit value (Long) + val i64 = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.Int64(i64) + } + MetadataType.FLOAT64 -> { + // 8-byte IEEE 754 double (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.") + // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double + val bits = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + val doubleVal = Double.fromBits(bits) + MetadataValue.Float64(doubleVal) + } + } + + + private fun <T> T?.takeUnless(check: T.() -> Boolean): T? = + this?.takeIf { !it.check() } + + /** Helper: Skip a value in the stream without storing it (still maintains pointer). */ + private fun skipValue(input: InputStream, type: MetadataType) { + when (type) { + MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1) + MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2) + MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4) + MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8) + MetadataType.STRING -> { + val len = readLittleLong(input); input.skipFully(len) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip + } + } + } + + /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */ + private fun readLittleLong(input: InputStream): Long { + val bytes = ByteArray(8) + input.readFully(bytes) + + // Combine 8 bytes into a 64-bit value (Little Endian). + // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement). 
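+        // (Example: little-endian bytes 2A 00 00 00 00 00 00 00 assemble to 0x2A = 42.)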
+ // In our context (lengths/counts), such extremely large values are not expected. + return (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + } + + /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ + private fun readString(input: InputStream): String = + // Read 8-byte little-endian length (number of bytes in the string). + readLittleLong(input).let { len -> + if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") + + // Read the UTF-8 bytes of the given length. + ByteArray(len.toInt()).let { + if (it.isNotEmpty()) input.readFully(it) + String(it, Charsets.UTF_8) + } + } + + /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ + private fun littleEndianBytesToInt(bytes: ByteArray): Int = + // Note: assumes bytes length is 4. + (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + + /** + * Robust skip that works the same on JDK 11 and Android’s desugared runtime. + * + * @param n Number of bytes to advance in the stream. + * @throws IOException on premature EOF. + */ + private fun InputStream.skipFully(n: Long) { + var remaining = n + val scratch = ByteArray(8192) // read‑and‑toss buffer + while (remaining > 0) { + val skipped = skip(remaining) + when { + skipped > 0 -> remaining -= skipped // normal fast path + skipped == 0L -> { + // fallback: read and discard + val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt()) + if (read == -1) throw IOException("EOF while skipping $n bytes") + remaining -= read + } + else -> throw IOException("Skip returned negative value") + } + } + } + + /** + * Extension that keeps reading until the requested number of bytes are filled. + * Falls back to `read()` when `skip()` returns 0, which happens on some Android + * streams. + * + * @param buf Destination buffer. + * @param len Number of bytes to fill (defaults to `buf.size`). + * @throws IOException on premature EOF. + */ + private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) { + var off = 0 + while (off < len) { + val n = read(buf, off, len - off) + if (n == -1) throw IOException("EOF after $off of $len bytes") + off += n + } + } + + /** + * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array. + * This is used for small fixed‑length reads (e.g. 4‑byte type codes). + * + * @throws IOException on premature EOF. + */ + private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also { + if (read(it) != n) throw IOException("Unexpected EOF") + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt b/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt new file mode 100644 index 0000000..cbbb974 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt @@ -0,0 +1,17 @@ +package android.llama.cpp + +import org.junit.Test + +import org.junit.Assert.* + +/** + * Example local unit test, which will execute on the development machine (host). + * + * See [testing documentation](http://d.android.com/tools/testing). 
+ */
+class ExampleUnitTest {
+    @Test
+    fun addition_isCorrect() {
+        assertEquals(4, 2 + 2)
+    }
+}
diff --git a/llama.cpp/examples/llama.android/settings.gradle.kts b/llama.cpp/examples/llama.android/settings.gradle.kts
new file mode 100644
index 0000000..74f4eb3
--- /dev/null
+++ b/llama.cpp/examples/llama.android/settings.gradle.kts
@@ -0,0 +1,18 @@
+pluginManagement {
+    repositories {
+        google()
+        mavenCentral()
+        gradlePluginPortal()
+    }
+}
+dependencyResolutionManagement {
+    repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS)
+    repositories {
+        mavenCentral()
+        google()
+    }
+}
+
+rootProject.name = "AiChat"
+include(":app")
+include(":lib")
diff --git a/llama.cpp/examples/llama.swiftui/.gitignore b/llama.cpp/examples/llama.swiftui/.gitignore
new file mode 100644
index 0000000..e585a2a
--- /dev/null
+++ b/llama.cpp/examples/llama.swiftui/.gitignore
@@ -0,0 +1,2 @@
+xcuserdata
+xcshareddata
diff --git a/llama.cpp/examples/llama.swiftui/README.md b/llama.cpp/examples/llama.swiftui/README.md
new file mode 100644
index 0000000..bd7ce37
--- /dev/null
+++ b/llama.cpp/examples/llama.swiftui/README.md
@@ -0,0 +1,27 @@
+# llama.cpp/examples/llama.swiftui
+
+Local inference of llama.cpp on an iPhone. This is a sample app that can be used as a starting
+point for more advanced projects.
+
+For usage instructions and performance stats, check the following discussion: https://github.com/ggml-org/llama.cpp/discussions/4508
+
+
+### Building
+First, llama.cpp needs to be built and an XCFramework needs to be created. This can be done by running
+the following script from the llama.cpp project root:
+```console
+$ ./build-xcframework.sh
+```
+Open the `llama.swiftui.xcodeproj` project in Xcode and you should be able to build and run the app on
+a simulator or a real device.
+
+To use the framework with a different project, the XCFramework can be added by dragging and
+dropping `build-apple/llama.xcframework` into the project navigator, or by manually selecting the
+framework in the "Frameworks, Libraries, and Embedded Content" section of the project settings.
+
+
+
+Video demonstration:
+
+https://github.com/bachittle/llama.cpp/assets/39804642/e290827a-4edb-4093-9642-2a5e399ec545
diff --git a/llama.cpp/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/llama.cpp/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
new file mode 100644
index 0000000..dc2bafc
--- /dev/null
+++ b/llama.cpp/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -0,0 +1,337 @@
+import Foundation
+import llama
+
+enum LlamaError: Error {
+    case couldNotInitializeContext
+}
+
+func llama_batch_clear(_ batch: inout llama_batch) {
+    batch.n_tokens = 0
+}
+
+func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
+    batch.token   [Int(batch.n_tokens)] = id
+    batch.pos     [Int(batch.n_tokens)] = pos
+    batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
+    for i in 0..<seq_ids.count {
+        batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
+    }
+    batch.logits  [Int(batch.n_tokens)] = logits ? 
+
+    batch.n_tokens += 1
+}
+
+actor LlamaContext {
+    private var model: OpaquePointer
+    private var context: OpaquePointer
+    private var vocab: OpaquePointer
+    private var sampling: UnsafeMutablePointer<llama_sampler>
+    private var batch: llama_batch
+    private var tokens_list: [llama_token]
+    var is_done: Bool = false
+
+    /// This variable is used to store temporarily invalid cchars
+    private var temporary_invalid_cchars: [CChar]
+
+    var n_len: Int32 = 1024
+    var n_cur: Int32 = 0
+
+    var n_decode: Int32 = 0
+
+    init(model: OpaquePointer, context: OpaquePointer) {
+        self.model = model
+        self.context = context
+        self.tokens_list = []
+        self.batch = llama_batch_init(512, 0, 1)
+        self.temporary_invalid_cchars = []
+        let sparams = llama_sampler_chain_default_params()
+        self.sampling = llama_sampler_chain_init(sparams)
+        llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
+        llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
+        vocab = llama_model_get_vocab(model)
+    }
+
+    deinit {
+        llama_sampler_free(sampling)
+        llama_batch_free(batch)
+        // free the context before the model it was created from
+        llama_free(context)
+        llama_model_free(model)
+        llama_backend_free()
+    }
+
+    static func create_context(path: String) throws -> LlamaContext {
+        llama_backend_init()
+        var model_params = llama_model_default_params()
+
+#if targetEnvironment(simulator)
+        model_params.n_gpu_layers = 0
+        print("Running on simulator, force use n_gpu_layers = 0")
+#endif
+        let model = llama_model_load_from_file(path, model_params)
+        guard let model else {
+            print("Could not load model at \(path)")
+            throw LlamaError.couldNotInitializeContext
+        }
+
+        let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
+        print("Using \(n_threads) threads")
+
+        var ctx_params = llama_context_default_params()
+        ctx_params.n_ctx = 2048
+        ctx_params.n_threads = Int32(n_threads)
+        ctx_params.n_threads_batch = Int32(n_threads)
+
+        let context = llama_init_from_model(model, ctx_params)
+        guard let context else {
+            print("Could not load context!")
+            throw LlamaError.couldNotInitializeContext
+        }
+
+        return LlamaContext(model: model, context: context)
+    }
+
+    func model_info() -> String {
+        let result = UnsafeMutablePointer<Int8>.allocate(capacity: 256)
+        result.initialize(repeating: Int8(0), count: 256)
+        defer {
+            result.deallocate()
+        }
+
+        // TODO: this is probably a very stupid way to get the string from C
+
+        let nChars = llama_model_desc(model, result, 256)
+        let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nChars))
+
+        var swiftString = ""
+        for char in bufferPointer {
+            // bitPattern avoids a crash on bytes >= 0x80 (CChar is signed)
+            swiftString.append(Character(UnicodeScalar(UInt8(bitPattern: char))))
+        }
+
+        return swiftString
+    }
+
+    func get_n_tokens() -> Int32 {
+        return batch.n_tokens
+    }
+
+    func completion_init(text: String) {
+        print("attempting to complete \"\(text)\"")
+
+        tokens_list = tokenize(text: text, add_bos: true)
+        temporary_invalid_cchars = []
+
+        let n_ctx = llama_n_ctx(context)
+        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)
+
+        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")
+
+        if n_kv_req > n_ctx {
+            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
+        }
+
+        for id in tokens_list {
+            print(String(cString: token_to_piece(token: id) + [0]))
+        }
+
+        llama_batch_clear(&batch)
+
+        for i1 in 0..<tokens_list.count {
+            let i = Int(i1)
+            llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
+        }
+        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+
+        if llama_decode(context, batch) != 0 {
+ 
print("llama_decode() failed") + } + + n_cur = batch.n_tokens + } + + func completion_loop() -> String { + var new_token_id: llama_token = 0 + + new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) + + if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { + print("\n") + is_done = true + let new_token_str = String(cString: temporary_invalid_cchars + [0]) + temporary_invalid_cchars.removeAll() + return new_token_str + } + + let new_token_cchars = token_to_piece(token: new_token_id) + temporary_invalid_cchars.append(contentsOf: new_token_cchars) + let new_token_str: String + if let string = String(validatingUTF8: temporary_invalid_cchars + [0]) { + temporary_invalid_cchars.removeAll() + new_token_str = string + } else if (0 ..< temporary_invalid_cchars.count).contains(where: {$0 != 0 && String(validatingUTF8: Array(temporary_invalid_cchars.suffix($0)) + [0]) != nil}) { + // in this case, at least the suffix of the temporary_invalid_cchars can be interpreted as UTF8 string + let string = String(cString: temporary_invalid_cchars + [0]) + temporary_invalid_cchars.removeAll() + new_token_str = string + } else { + new_token_str = "" + } + print(new_token_str) + // tokens_list.append(new_token_id) + + llama_batch_clear(&batch) + llama_batch_add(&batch, new_token_id, n_cur, [0], true) + + n_decode += 1 + n_cur += 1 + + if llama_decode(context, batch) != 0 { + print("failed to evaluate llama!") + } + + return new_token_str + } + + func bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) -> String { + var pp_avg: Double = 0 + var tg_avg: Double = 0 + + var pp_std: Double = 0 + var tg_std: Double = 0 + + for _ in 0..<nr { + // bench prompt processing + + llama_batch_clear(&batch) + + let n_tokens = pp + + for i in 0..<n_tokens { + llama_batch_add(&batch, 0, Int32(i), [0], false) + } + batch.logits[Int(batch.n_tokens) - 1] = 1 // true + + llama_memory_clear(llama_get_memory(context), false) + + let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000; + + if llama_decode(context, batch) != 0 { + print("llama_decode() failed during prompt") + } + llama_synchronize(context) + + let t_pp_end = DispatchTime.now().uptimeNanoseconds / 1000; + + // bench text generation + + llama_memory_clear(llama_get_memory(context), false) + + let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000; + + for i in 0..<tg { + llama_batch_clear(&batch) + + for j in 0..<pl { + llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true) + } + + if llama_decode(context, batch) != 0 { + print("llama_decode() failed during text generation") + } + llama_synchronize(context) + } + + let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000; + + llama_memory_clear(llama_get_memory(context), false) + + let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 + let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 + + let speed_pp = Double(pp) / t_pp + let speed_tg = Double(pl*tg) / t_tg + + pp_avg += speed_pp + tg_avg += speed_tg + + pp_std += speed_pp * speed_pp + tg_std += speed_tg * speed_tg + + print("pp \(speed_pp) t/s, tg \(speed_tg) t/s") + } + + pp_avg /= Double(nr) + tg_avg /= Double(nr) + + if nr > 1 { + pp_std = sqrt(pp_std / Double(nr - 1) - pp_avg * pp_avg * Double(nr) / Double(nr - 1)) + tg_std = sqrt(tg_std / Double(nr - 1) - tg_avg * tg_avg * Double(nr) / Double(nr - 1)) + } else { + pp_std = 0 + tg_std = 0 + } + + let model_desc = model_info(); + let model_size = String(format: "%.2f GiB", Double(llama_model_size(model)) / 1024.0 / 1024.0 / 1024.0); + let model_n_params = String(format: "%.2f 
B", Double(llama_model_n_params(model)) / 1e9); + let backend = "Metal"; + let pp_avg_str = String(format: "%.2f", pp_avg); + let tg_avg_str = String(format: "%.2f", tg_avg); + let pp_std_str = String(format: "%.2f", pp_std); + let tg_std_str = String(format: "%.2f", tg_std); + + var result = "" + + result += String("| model | size | params | backend | test | t/s |\n") + result += String("| --- | --- | --- | --- | --- | --- |\n") + result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | pp \(pp) | \(pp_avg_str) ± \(pp_std_str) |\n") + result += String("| \(model_desc) | \(model_size) | \(model_n_params) | \(backend) | tg \(tg) | \(tg_avg_str) ± \(tg_std_str) |\n") + + return result; + } + + func clear() { + tokens_list.removeAll() + temporary_invalid_cchars.removeAll() + llama_memory_clear(llama_get_memory(context), true) + } + + private func tokenize(text: String, add_bos: Bool) -> [llama_token] { + let utf8Count = text.utf8.count + let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1 + let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens) + let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false) + + var swiftTokens: [llama_token] = [] + for i in 0..<tokenCount { + swiftTokens.append(tokens[Int(i)]) + } + + tokens.deallocate() + + return swiftTokens + } + + /// - note: The result does not contain null-terminator + private func token_to_piece(token: llama_token) -> [CChar] { + let result = UnsafeMutablePointer<Int8>.allocate(capacity: 8) + result.initialize(repeating: Int8(0), count: 8) + defer { + result.deallocate() + } + let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false) + + if nTokens < 0 { + let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens)) + newResult.initialize(repeating: Int8(0), count: Int(-nTokens)) + defer { + newResult.deallocate() + } + let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false) + let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens)) + return Array(bufferPointer) + } else { + let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nTokens)) + return Array(bufferPointer) + } + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj b/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj new file mode 100644 index 0000000..6f08fe2 --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.pbxproj @@ -0,0 +1,449 @@ +// !$*UTF8*$! 
+{ + archiveVersion = 1; + classes = { + }; + objectVersion = 56; + objects = { + +/* Begin PBXBuildFile section */ + 549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 549479CA2AC9E16000E0F78B /* Metal.framework */; }; + 79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */; }; + 7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */; }; + 8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */; }; + 8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A1C83782AC328BD0096AF73 /* ContentView.swift */; }; + 8A1C837B2AC328BE0096AF73 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8A1C837A2AC328BE0096AF73 /* Assets.xcassets */; }; + 8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 8A39BE092AC7601000BFEB40 /* Accelerate.framework */; }; + 8A3F84242AC4C891005E2EE8 /* models in Resources */ = {isa = PBXBuildFile; fileRef = 8A3F84232AC4C891005E2EE8 /* models */; }; + 8A907F332AC7138A006146EA /* LibLlama.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A907F322AC7134E006146EA /* LibLlama.swift */; }; + 8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */; }; + DD84C9FD2D747FED007778EC /* llama.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = DD84C9FC2D747FED007778EC /* llama.xcframework */; }; + DD84C9FE2D747FED007778EC /* llama.xcframework in Embed Frameworks */ = {isa = PBXBuildFile; fileRef = DD84C9FC2D747FED007778EC /* llama.xcframework */; settings = {ATTRIBUTES = (CodeSignOnCopy, RemoveHeadersOnCopy, ); }; }; + F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */ = {isa = PBXBuildFile; fileRef = F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */; }; +/* End PBXBuildFile section */ + +/* Begin PBXCopyFilesBuildPhase section */ + DD84C9FF2D747FED007778EC /* Embed Frameworks */ = { + isa = PBXCopyFilesBuildPhase; + buildActionMask = 2147483647; + dstPath = ""; + dstSubfolderSpec = 10; + files = ( + DD84C9FE2D747FED007778EC /* llama.xcframework in Embed Frameworks */, + ); + name = "Embed Frameworks"; + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXCopyFilesBuildPhase section */ + +/* Begin PBXFileReference section */ + 549479CA2AC9E16000E0F78B /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; }; + 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = InputButton.swift; sourceTree = "<group>"; }; + 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = DownloadButton.swift; sourceTree = "<group>"; }; + 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = llama.swiftui.app; sourceTree = BUILT_PRODUCTS_DIR; }; + 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = llama_swiftuiApp.swift; sourceTree = 
"<group>"; }; + 8A1C83782AC328BD0096AF73 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; }; + 8A1C837A2AC328BE0096AF73 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; }; + 8A39BE092AC7601000BFEB40 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; + 8A3F84232AC4C891005E2EE8 /* models */ = {isa = PBXFileReference; lastKnownFileType = folder; name = models; path = llama.swiftui/Resources/models; sourceTree = "<group>"; }; + 8A907F322AC7134E006146EA /* LibLlama.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LibLlama.swift; sourceTree = "<group>"; }; + 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LlamaState.swift; sourceTree = "<group>"; }; + DD84C9FC2D747FED007778EC /* llama.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = llama.xcframework; path = "../../build-apple/llama.xcframework"; sourceTree = "<group>"; }; + DF2D2FE72B4A59BE00FCB72D /* llama.cpp */ = {isa = PBXFileReference; lastKnownFileType = wrapper; name = llama.cpp; path = ../..; sourceTree = "<group>"; }; + F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = LoadCustomButton.swift; sourceTree = "<group>"; }; +/* End PBXFileReference section */ + +/* Begin PBXFrameworksBuildPhase section */ + 8A1C83702AC328BD0096AF73 /* Frameworks */ = { + isa = PBXFrameworksBuildPhase; + buildActionMask = 2147483647; + files = ( + 549479CB2AC9E16000E0F78B /* Metal.framework in Frameworks */, + 8A39BE0A2AC7601100BFEB40 /* Accelerate.framework in Frameworks */, + DD84C9FD2D747FED007778EC /* llama.xcframework in Frameworks */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXFrameworksBuildPhase section */ + +/* Begin PBXGroup section */ + 8A1C836A2AC328BD0096AF73 = { + isa = PBXGroup; + children = ( + DF2D2FE72B4A59BE00FCB72D /* llama.cpp */, + 8A907F312AC7134E006146EA /* llama.cpp.swift */, + 8A3F84232AC4C891005E2EE8 /* models */, + 8A1C83752AC328BD0096AF73 /* llama.swiftui */, + 8A1C83742AC328BD0096AF73 /* Products */, + 8A39BE082AC7601000BFEB40 /* Frameworks */, + ); + sourceTree = "<group>"; + }; + 8A1C83742AC328BD0096AF73 /* Products */ = { + isa = PBXGroup; + children = ( + 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */, + ); + name = Products; + sourceTree = "<group>"; + }; + 8A1C83752AC328BD0096AF73 /* llama.swiftui */ = { + isa = PBXGroup; + children = ( + 8A3F84102AC4BD85005E2EE8 /* Resources */, + 8A9F7C4B2AC332DC008AE1EA /* Models */, + 8A9F7C4A2AC332BF008AE1EA /* UI */, + 8A1C83762AC328BD0096AF73 /* llama_swiftuiApp.swift */, + 8A1C837A2AC328BE0096AF73 /* Assets.xcassets */, + ); + path = llama.swiftui; + sourceTree = "<group>"; + }; + 8A39BE082AC7601000BFEB40 /* Frameworks */ = { + isa = PBXGroup; + children = ( + DD84C9FC2D747FED007778EC /* llama.xcframework */, + 549479CA2AC9E16000E0F78B /* Metal.framework */, + 8A39BE092AC7601000BFEB40 /* Accelerate.framework */, + ); + name = Frameworks; + sourceTree = "<group>"; + }; + 8A3F84102AC4BD85005E2EE8 /* Resources */ = { + isa = PBXGroup; + children = ( + 8A3F84112AC4BD8C005E2EE8 /* models */, + ); + path = Resources; + 
sourceTree = "<group>"; + }; + 8A3F84112AC4BD8C005E2EE8 /* models */ = { + isa = PBXGroup; + children = ( + ); + path = models; + sourceTree = "<group>"; + }; + 8A907F312AC7134E006146EA /* llama.cpp.swift */ = { + isa = PBXGroup; + children = ( + 8A907F322AC7134E006146EA /* LibLlama.swift */, + ); + path = llama.cpp.swift; + sourceTree = "<group>"; + }; + 8A9F7C4A2AC332BF008AE1EA /* UI */ = { + isa = PBXGroup; + children = ( + 7FA3D2B22B2EA2F600543F92 /* DownloadButton.swift */, + 8A1C83782AC328BD0096AF73 /* ContentView.swift */, + F1FE20E12B465EC900B45541 /* LoadCustomButton.swift */, + 79E1D9CC2B4CD16E005F8E46 /* InputButton.swift */, + ); + path = UI; + sourceTree = "<group>"; + }; + 8A9F7C4B2AC332DC008AE1EA /* Models */ = { + isa = PBXGroup; + children = ( + 8A9F7C4C2AC332EE008AE1EA /* LlamaState.swift */, + ); + path = Models; + sourceTree = "<group>"; + }; +/* End PBXGroup section */ + +/* Begin PBXNativeTarget section */ + 8A1C83722AC328BD0096AF73 /* llama.swiftui */ = { + isa = PBXNativeTarget; + buildConfigurationList = 8A1C83812AC328BE0096AF73 /* Build configuration list for PBXNativeTarget "llama.swiftui" */; + buildPhases = ( + 8A1C836F2AC328BD0096AF73 /* Sources */, + 8A1C83702AC328BD0096AF73 /* Frameworks */, + 8A1C83712AC328BD0096AF73 /* Resources */, + DD84C9FF2D747FED007778EC /* Embed Frameworks */, + ); + buildRules = ( + ); + dependencies = ( + ); + name = llama.swiftui; + packageProductDependencies = ( + ); + productName = llama.swiftui; + productReference = 8A1C83732AC328BD0096AF73 /* llama.swiftui.app */; + productType = "com.apple.product-type.application"; + }; +/* End PBXNativeTarget section */ + +/* Begin PBXProject section */ + 8A1C836B2AC328BD0096AF73 /* Project object */ = { + isa = PBXProject; + attributes = { + BuildIndependentTargetsInParallel = 1; + LastSwiftUpdateCheck = 1500; + LastUpgradeCheck = 1500; + TargetAttributes = { + 8A1C83722AC328BD0096AF73 = { + CreatedOnToolsVersion = 15.0; + LastSwiftMigration = 1500; + }; + }; + }; + buildConfigurationList = 8A1C836E2AC328BD0096AF73 /* Build configuration list for PBXProject "llama.swiftui" */; + compatibilityVersion = "Xcode 14.0"; + developmentRegion = en; + hasScannedForEncodings = 0; + knownRegions = ( + en, + Base, + ); + mainGroup = 8A1C836A2AC328BD0096AF73; + packageReferences = ( + ); + productRefGroup = 8A1C83742AC328BD0096AF73 /* Products */; + projectDirPath = ""; + projectRoot = ""; + targets = ( + 8A1C83722AC328BD0096AF73 /* llama.swiftui */, + ); + }; +/* End PBXProject section */ + +/* Begin PBXResourcesBuildPhase section */ + 8A1C83712AC328BD0096AF73 /* Resources */ = { + isa = PBXResourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + 8A3F84242AC4C891005E2EE8 /* models in Resources */, + 8A1C837B2AC328BE0096AF73 /* Assets.xcassets in Resources */, + ); + runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXResourcesBuildPhase section */ + +/* Begin PBXSourcesBuildPhase section */ + 8A1C836F2AC328BD0096AF73 /* Sources */ = { + isa = PBXSourcesBuildPhase; + buildActionMask = 2147483647; + files = ( + F1FE20E22B465ECA00B45541 /* LoadCustomButton.swift in Sources */, + 8A907F332AC7138A006146EA /* LibLlama.swift in Sources */, + 8A9F7C4D2AC332EE008AE1EA /* LlamaState.swift in Sources */, + 8A1C83792AC328BD0096AF73 /* ContentView.swift in Sources */, + 8A1C83772AC328BD0096AF73 /* llama_swiftuiApp.swift in Sources */, + 7FA3D2B32B2EA2F600543F92 /* DownloadButton.swift in Sources */, + 79E1D9CD2B4CD16E005F8E46 /* InputButton.swift in Sources */, + ); + 
runOnlyForDeploymentPostprocessing = 0; + }; +/* End PBXSourcesBuildPhase section */ + +/* Begin XCBuildConfiguration section */ + 8A1C837F2AC328BE0096AF73 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = dwarf; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_TESTABILITY = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_DYNAMIC_NO_PIC = NO; + GCC_NO_COMMON_BLOCKS = YES; + GCC_OPTIMIZATION_LEVEL = 0; + GCC_PREPROCESSOR_DEFINITIONS = ( + "DEBUG=1", + "$(inherited)", + ); + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 17.0; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; + MTL_FAST_MATH = YES; + ONLY_ACTIVE_ARCH = YES; + SDKROOT = iphoneos; + SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)"; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + }; + name = Debug; + }; + 8A1C83802AC328BE0096AF73 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ALWAYS_SEARCH_USER_PATHS = NO; + ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES; + CLANG_ANALYZER_NONNULL = YES; + CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; + CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; + CLANG_ENABLE_MODULES = YES; + CLANG_ENABLE_OBJC_ARC = YES; + CLANG_ENABLE_OBJC_WEAK = YES; + CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; + CLANG_WARN_BOOL_CONVERSION = YES; + CLANG_WARN_COMMA = YES; + CLANG_WARN_CONSTANT_CONVERSION = YES; + CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; + CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; + CLANG_WARN_DOCUMENTATION_COMMENTS = YES; + CLANG_WARN_EMPTY_BODY = YES; + CLANG_WARN_ENUM_CONVERSION = YES; + CLANG_WARN_INFINITE_RECURSION = YES; + CLANG_WARN_INT_CONVERSION = YES; + CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; + CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; + CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; + CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; + CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; + CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; + 
CLANG_WARN_STRICT_PROTOTYPES = YES; + CLANG_WARN_SUSPICIOUS_MOVE = YES; + CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; + CLANG_WARN_UNREACHABLE_CODE = YES; + CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; + COPY_PHASE_STRIP = NO; + DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym"; + ENABLE_NS_ASSERTIONS = NO; + ENABLE_STRICT_OBJC_MSGSEND = YES; + ENABLE_USER_SCRIPT_SANDBOXING = YES; + GCC_C_LANGUAGE_STANDARD = gnu17; + GCC_NO_COMMON_BLOCKS = YES; + GCC_WARN_64_TO_32_BIT_CONVERSION = YES; + GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; + GCC_WARN_UNDECLARED_SELECTOR = YES; + GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; + GCC_WARN_UNUSED_FUNCTION = YES; + GCC_WARN_UNUSED_VARIABLE = YES; + IPHONEOS_DEPLOYMENT_TARGET = 17.0; + LOCALIZATION_PREFERS_STRING_CATALOGS = YES; + MTL_ENABLE_DEBUG_INFO = NO; + MTL_FAST_MATH = YES; + SDKROOT = iphoneos; + SWIFT_COMPILATION_MODE = wholemodule; + VALIDATE_PRODUCT = YES; + }; + name = Release; + }; + 8A1C83822AC328BE0096AF73 /* Debug */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = K5UQJPP73A; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator"; + SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2,7"; + }; + name = Debug; + }; + 8A1C83832AC328BE0096AF73 /* Release */ = { + isa = XCBuildConfiguration; + buildSettings = { + ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; + CLANG_ENABLE_MODULES = YES; + CODE_SIGN_STYLE = Automatic; + CURRENT_PROJECT_VERSION = 1; + DEVELOPMENT_TEAM = K5UQJPP73A; + ENABLE_PREVIEWS = YES; + GENERATE_INFOPLIST_FILE = YES; + INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES; + INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES; + INFOPLIST_KEY_UILaunchScreen_Generation = YES; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; + IPHONEOS_DEPLOYMENT_TARGET = 16.0; + LD_RUNPATH_SEARCH_PATHS = ( + "$(inherited)", + "@executable_path/Frameworks", + ); + MARKETING_VERSION = 1.0; + PRODUCT_BUNDLE_IDENTIFIER = "com.bachittle.llama-swift"; + PRODUCT_NAME = "$(TARGET_NAME)"; + SUPPORTED_PLATFORMS = "iphoneos iphonesimulator xros xrsimulator"; + SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + 
SWIFT_EMIT_LOC_STRINGS = YES; + SWIFT_VERSION = 5.0; + TARGETED_DEVICE_FAMILY = "1,2,7"; + }; + name = Release; + }; +/* End XCBuildConfiguration section */ + +/* Begin XCConfigurationList section */ + 8A1C836E2AC328BD0096AF73 /* Build configuration list for PBXProject "llama.swiftui" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 8A1C837F2AC328BE0096AF73 /* Debug */, + 8A1C83802AC328BE0096AF73 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; + 8A1C83812AC328BE0096AF73 /* Build configuration list for PBXNativeTarget "llama.swiftui" */ = { + isa = XCConfigurationList; + buildConfigurations = ( + 8A1C83822AC328BE0096AF73 /* Debug */, + 8A1C83832AC328BE0096AF73 /* Release */, + ); + defaultConfigurationIsVisible = 0; + defaultConfigurationName = Release; + }; +/* End XCConfigurationList section */ + }; + rootObject = 8A1C836B2AC328BD0096AF73 /* Project object */; +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.xcworkspace/contents.xcworkspacedata b/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.xcworkspace/contents.xcworkspacedata new file mode 100644 index 0000000..919434a --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui.xcodeproj/project.xcworkspace/contents.xcworkspacedata @@ -0,0 +1,7 @@ +<?xml version="1.0" encoding="UTF-8"?> +<Workspace + version = "1.0"> + <FileRef + location = "self:"> + </FileRef> +</Workspace> diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/AppIcon.appiconset/Contents.json b/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/AppIcon.appiconset/Contents.json new file mode 100644 index 0000000..13613e3 --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/AppIcon.appiconset/Contents.json @@ -0,0 +1,13 @@ +{ + "images" : [ + { + "idiom" : "universal", + "platform" : "ios", + "size" : "1024x1024" + } + ], + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/Contents.json b/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/Contents.json new file mode 100644 index 0000000..73c0059 --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/Assets.xcassets/Contents.json @@ -0,0 +1,6 @@ +{ + "info" : { + "author" : "xcode", + "version" : 1 + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift new file mode 100644 index 0000000..b8f6a31 --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift @@ -0,0 +1,196 @@ +import Foundation + +struct Model: Identifiable { + var id = UUID() + var name: String + var url: String + var filename: String + var status: String? +} + +@MainActor +class LlamaState: ObservableObject { + @Published var messageLog = "" + @Published var cacheCleared = false + @Published var downloadedModels: [Model] = [] + @Published var undownloadedModels: [Model] = [] + let NS_PER_S = 1_000_000_000.0 + + private var llamaContext: LlamaContext? + private var defaultModelUrl: URL? 
{
+        Bundle.main.url(forResource: "ggml-model", withExtension: "gguf", subdirectory: "models")
+        // Bundle.main.url(forResource: "llama-2-7b-chat", withExtension: "Q2_K.gguf", subdirectory: "models")
+    }
+
+    init() {
+        loadModelsFromDisk()
+        loadDefaultModels()
+    }
+
+    private func loadModelsFromDisk() {
+        do {
+            let documentsURL = getDocumentsDirectory()
+            let modelURLs = try FileManager.default.contentsOfDirectory(at: documentsURL, includingPropertiesForKeys: nil, options: [.skipsHiddenFiles, .skipsSubdirectoryDescendants])
+            for modelURL in modelURLs {
+                let modelName = modelURL.deletingPathExtension().lastPathComponent
+                downloadedModels.append(Model(name: modelName, url: "", filename: modelURL.lastPathComponent, status: "downloaded"))
+            }
+        } catch {
+            print("Error loading models from disk: \(error)")
+        }
+    }
+
+    private func loadDefaultModels() {
+        do {
+            try loadModel(modelUrl: defaultModelUrl)
+        } catch {
+            messageLog += "Error loading default model: \(error)\n"
+        }
+
+        for model in defaultModels {
+            let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename)
+            // queue for download only the models that are not already on disk
+            if !FileManager.default.fileExists(atPath: fileURL.path) {
+                var undownloadedModel = model
+                undownloadedModel.status = "download"
+                undownloadedModels.append(undownloadedModel)
+            }
+        }
+    }
+
+    func getDocumentsDirectory() -> URL {
+        let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)
+        return paths[0]
+    }
+
+    private let defaultModels: [Model] = [
+        Model(
+            name: "TinyLlama-1.1B (Q4_0, 0.6 GiB)",
+            url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
+            filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf", status: "download"
+        ),
+        Model(
+            name: "TinyLlama-1.1B Chat (Q8_0, 1.1 GiB)",
+            url: "https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf?download=true",
+            filename: "tinyllama-1.1b-chat-v1.0.Q8_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "TinyLlama-1.1B (F16, 2.2 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/tinyllama-1.1b/ggml-model-f16.gguf?download=true",
+            filename: "tinyllama-1.1b-f16.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Phi-2.7B (Q4_0, 1.6 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q4_0.gguf?download=true",
+            filename: "phi-2-q4_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Phi-2.7B (Q8_0, 2.8 GiB)",
+            url: "https://huggingface.co/ggml-org/models/resolve/main/phi-2/ggml-model-q8_0.gguf?download=true",
+            filename: "phi-2-q8_0.gguf", status: "download"
+        ),
+
+        Model(
+            name: "Mistral-7B-v0.1 (Q4_0, 3.8 GiB)",
+            url: "https://huggingface.co/TheBloke/Mistral-7B-v0.1-GGUF/resolve/main/mistral-7b-v0.1.Q4_0.gguf?download=true",
+            filename: "mistral-7b-v0.1.Q4_0.gguf", status: "download"
+        ),
+        Model(
+            name: "OpenHermes-2.5-Mistral-7B (Q3_K_M, 3.52 GiB)",
+            url: "https://huggingface.co/TheBloke/OpenHermes-2.5-Mistral-7B-GGUF/resolve/main/openhermes-2.5-mistral-7b.Q3_K_M.gguf?download=true",
+            filename: "openhermes-2.5-mistral-7b.Q3_K_M.gguf", status: "download"
+        )
+    ]
+
+    func loadModel(modelUrl: URL?)
throws { + if let modelUrl { + messageLog += "Loading model...\n" + llamaContext = try LlamaContext.create_context(path: modelUrl.path()) + messageLog += "Loaded model \(modelUrl.lastPathComponent)\n" + + // Assuming that the model is successfully loaded, update the downloaded models + updateDownloadedModels(modelName: modelUrl.lastPathComponent, status: "downloaded") + } else { + messageLog += "Load a model from the list below\n" + } + } + + + private func updateDownloadedModels(modelName: String, status: String) { + undownloadedModels.removeAll { $0.name == modelName } + } + + + func complete(text: String) async { + guard let llamaContext else { + return + } + + let t_start = DispatchTime.now().uptimeNanoseconds + await llamaContext.completion_init(text: text) + let t_heat_end = DispatchTime.now().uptimeNanoseconds + let t_heat = Double(t_heat_end - t_start) / NS_PER_S + + messageLog += "\(text)" + + Task.detached { + while await !llamaContext.is_done { + let result = await llamaContext.completion_loop() + await MainActor.run { + self.messageLog += "\(result)" + } + } + + let t_end = DispatchTime.now().uptimeNanoseconds + let t_generation = Double(t_end - t_heat_end) / self.NS_PER_S + let tokens_per_second = Double(await llamaContext.n_len) / t_generation + + await llamaContext.clear() + + await MainActor.run { + self.messageLog += """ + \n + Done + Heat up took \(t_heat)s + Generated \(tokens_per_second) t/s\n + """ + } + } + } + + func bench() async { + guard let llamaContext else { + return + } + + messageLog += "\n" + messageLog += "Running benchmark...\n" + messageLog += "Model info: " + messageLog += await llamaContext.model_info() + "\n" + + let t_start = DispatchTime.now().uptimeNanoseconds + let _ = await llamaContext.bench(pp: 8, tg: 4, pl: 1) // heat up + let t_end = DispatchTime.now().uptimeNanoseconds + + let t_heat = Double(t_end - t_start) / NS_PER_S + messageLog += "Heat up time: \(t_heat) seconds, please wait...\n" + + // if more than 5 seconds, then we're probably running on a slow device + if t_heat > 5.0 { + messageLog += "Heat up time is too long, aborting benchmark\n" + return + } + + let result = await llamaContext.bench(pp: 512, tg: 128, pl: 1, nr: 3) + + messageLog += "\(result)" + messageLog += "\n" + } + + func clear() async { + guard let llamaContext else { + return + } + + await llamaContext.clear() + messageLog = "" + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/Resources/models/.gitignore b/llama.cpp/examples/llama.swiftui/llama.swiftui/Resources/models/.gitignore new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/Resources/models/.gitignore diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift new file mode 100644 index 0000000..1c3cd9d --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift @@ -0,0 +1,156 @@ +import SwiftUI + +struct ContentView: View { + @StateObject var llamaState = LlamaState() + @State private var multiLineText = "" + @State private var showingHelp = false // To track if Help Sheet should be shown + + var body: some View { + NavigationView { + VStack { + ScrollView(.vertical, showsIndicators: true) { + Text(llamaState.messageLog) + .font(.system(size: 12)) + .frame(maxWidth: .infinity, alignment: .leading) + .padding() + .onTapGesture { + UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: 
nil, for: nil) + } + } + + TextEditor(text: $multiLineText) + .frame(height: 80) + .padding() + .border(Color.gray, width: 0.5) + + HStack { + Button("Send") { + sendText() + } + + Button("Bench") { + bench() + } + + Button("Clear") { + clear() + } + + Button("Copy") { + UIPasteboard.general.string = llamaState.messageLog + } + } + .buttonStyle(.bordered) + .padding() + + NavigationLink(destination: DrawerView(llamaState: llamaState)) { + Text("View Models") + } + .padding() + + } + .padding() + .navigationBarTitle("Model Settings", displayMode: .inline) + + } + } + + func sendText() { + Task { + await llamaState.complete(text: multiLineText) + multiLineText = "" + } + } + + func bench() { + Task { + await llamaState.bench() + } + } + + func clear() { + Task { + await llamaState.clear() + } + } + struct DrawerView: View { + + @ObservedObject var llamaState: LlamaState + @State private var showingHelp = false + func delete(at offsets: IndexSet) { + offsets.forEach { offset in + let model = llamaState.downloadedModels[offset] + let fileURL = getDocumentsDirectory().appendingPathComponent(model.filename) + do { + try FileManager.default.removeItem(at: fileURL) + } catch { + print("Error deleting file: \(error)") + } + } + + // Remove models from downloadedModels array + llamaState.downloadedModels.remove(atOffsets: offsets) + } + + func getDocumentsDirectory() -> URL { + let paths = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask) + return paths[0] + } + var body: some View { + List { + Section(header: Text("Download Models From Hugging Face")) { + HStack { + InputButton(llamaState: llamaState) + } + } + Section(header: Text("Downloaded Models")) { + ForEach(llamaState.downloadedModels) { model in + DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename) + } + .onDelete(perform: delete) + } + Section(header: Text("Default Models")) { + ForEach(llamaState.undownloadedModels) { model in + DownloadButton(llamaState: llamaState, modelName: model.name, modelUrl: model.url, filename: model.filename) + } + } + + } + .listStyle(GroupedListStyle()) + .navigationBarTitle("Model Settings", displayMode: .inline).toolbar { + ToolbarItem(placement: .navigationBarTrailing) { + Button("Help") { + showingHelp = true + } + } + }.sheet(isPresented: $showingHelp) { // Sheet for help modal + NavigationView { + VStack(alignment: .leading) { + VStack(alignment: .leading) { + Text("1. Make sure the model is in GGUF Format") + .padding() + Text("2. Copy the download link of the quantized model") + .padding() + } + Spacer() + } + .navigationTitle("Help") + .navigationBarTitleDisplayMode(.inline) + .toolbar { + ToolbarItem(placement: .navigationBarTrailing) { + Button("Done") { + showingHelp = false + } + } + } + } + } + } + } +} + +struct ContentView_Previews: PreviewProvider { + static var previews: some View { + ContentView() + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift new file mode 100644 index 0000000..4584d6e --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift @@ -0,0 +1,124 @@ +import SwiftUI + +struct DownloadButton: View { + @ObservedObject private var llamaState: LlamaState + private var modelName: String + private var modelUrl: String + private var filename: String + + @State private var status: String + + @State private var downloadTask: URLSessionDownloadTask? 
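+    // `status` drives the button through a small state machine:
+    // "download" (file not on disk) -> "downloading" (URLSession task in flight,
+    // with `progress` updated through the KVO observation below) -> "downloaded"
+    // (file saved to the documents directory and ready to be loaded).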
+ @State private var progress = 0.0 + @State private var observation: NSKeyValueObservation? + + private static func getFileURL(filename: String) -> URL { + FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename) + } + + private func checkFileExistenceAndUpdateStatus() { + } + + init(llamaState: LlamaState, modelName: String, modelUrl: String, filename: String) { + self.llamaState = llamaState + self.modelName = modelName + self.modelUrl = modelUrl + self.filename = filename + + let fileURL = DownloadButton.getFileURL(filename: filename) + status = FileManager.default.fileExists(atPath: fileURL.path) ? "downloaded" : "download" + } + + private func download() { + status = "downloading" + print("Downloading model \(modelName) from \(modelUrl)") + guard let url = URL(string: modelUrl) else { return } + let fileURL = DownloadButton.getFileURL(filename: filename) + + downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in + if let error = error { + print("Error: \(error.localizedDescription)") + return + } + + guard let response = response as? HTTPURLResponse, (200...299).contains(response.statusCode) else { + print("Server error!") + return + } + + do { + if let temporaryURL = temporaryURL { + try FileManager.default.copyItem(at: temporaryURL, to: fileURL) + print("Writing to \(filename) completed") + + llamaState.cacheCleared = false + + let model = Model(name: modelName, url: modelUrl, filename: filename, status: "downloaded") + llamaState.downloadedModels.append(model) + status = "downloaded" + } + } catch let err { + print("Error: \(err.localizedDescription)") + } + } + + observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in + self.progress = progress.fractionCompleted + } + + downloadTask?.resume() + } + + var body: some View { + VStack { + if status == "download" { + Button(action: download) { + Text("Download " + modelName) + } + } else if status == "downloading" { + Button(action: { + downloadTask?.cancel() + status = "download" + }) { + Text("\(modelName) (Downloading \(Int(progress * 100))%)") + } + } else if status == "downloaded" { + Button(action: { + let fileURL = DownloadButton.getFileURL(filename: filename) + if !FileManager.default.fileExists(atPath: fileURL.path) { + download() + return + } + do { + try llamaState.loadModel(modelUrl: fileURL) + } catch let err { + print("Error: \(err.localizedDescription)") + } + }) { + Text("Load \(modelName)") + } + } else { + Text("Unknown status") + } + } + .onDisappear() { + downloadTask?.cancel() + } + .onChange(of: llamaState.cacheCleared) { newValue in + if newValue { + downloadTask?.cancel() + let fileURL = DownloadButton.getFileURL(filename: filename) + status = FileManager.default.fileExists(atPath: fileURL.path) ? 
"downloaded" : "download" + } + } + } +} + +// #Preview { +// DownloadButton( +// llamaState: LlamaState(), +// modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)", +// modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true", +// filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf" +// ) +// } diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift new file mode 100644 index 0000000..c5ffbad --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/InputButton.swift @@ -0,0 +1,131 @@ +import SwiftUI + +struct InputButton: View { + @ObservedObject var llamaState: LlamaState + @State private var inputLink: String = "" + @State private var status: String = "download" + @State private var filename: String = "" + + @State private var downloadTask: URLSessionDownloadTask? + @State private var progress = 0.0 + @State private var observation: NSKeyValueObservation? + + private static func extractModelInfo(from link: String) -> (modelName: String, filename: String)? { + guard let url = URL(string: link), + let lastPathComponent = url.lastPathComponent.components(separatedBy: ".").first, + let modelName = lastPathComponent.components(separatedBy: "-").dropLast().joined(separator: "-").removingPercentEncoding, + let filename = lastPathComponent.removingPercentEncoding else { + return nil + } + + return (modelName, filename) + } + + private static func getFileURL(filename: String) -> URL { + FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0].appendingPathComponent(filename) + } + + private func download() { + guard let extractedInfo = InputButton.extractModelInfo(from: inputLink) else { + // Handle invalid link or extraction failure + return + } + + let (modelName, filename) = extractedInfo + self.filename = filename // Set the state variable + + status = "downloading" + print("Downloading model \(modelName) from \(inputLink)") + guard let url = URL(string: inputLink) else { return } + let fileURL = InputButton.getFileURL(filename: filename) + + downloadTask = URLSession.shared.downloadTask(with: url) { temporaryURL, response, error in + if let error = error { + print("Error: \(error.localizedDescription)") + return + } + + guard let response = response as? 
HTTPURLResponse, (200...299).contains(response.statusCode) else { + print("Server error!") + return + } + + do { + if let temporaryURL = temporaryURL { + try FileManager.default.copyItem(at: temporaryURL, to: fileURL) + print("Writing to \(filename) completed") + + llamaState.cacheCleared = false + + let model = Model(name: modelName, url: self.inputLink, filename: filename, status: "downloaded") + llamaState.downloadedModels.append(model) + status = "downloaded" + } + } catch let err { + print("Error: \(err.localizedDescription)") + } + } + + observation = downloadTask?.progress.observe(\.fractionCompleted) { progress, _ in + self.progress = progress.fractionCompleted + } + + downloadTask?.resume() + } + + var body: some View { + VStack { + HStack { + TextField("Paste Quantized Download Link", text: $inputLink) + .textFieldStyle(RoundedBorderTextFieldStyle()) + + Button(action: { + downloadTask?.cancel() + status = "download" + }) { + Text("Cancel") + } + } + + if status == "download" { + Button(action: download) { + Text("Download Custom Model") + } + } else if status == "downloading" { + Button(action: { + downloadTask?.cancel() + status = "download" + }) { + Text("Downloading \(Int(progress * 100))%") + } + } else if status == "downloaded" { + Button(action: { + let fileURL = InputButton.getFileURL(filename: self.filename) + if !FileManager.default.fileExists(atPath: fileURL.path) { + download() + return + } + do { + try llamaState.loadModel(modelUrl: fileURL) + } catch let err { + print("Error: \(err.localizedDescription)") + } + }) { + Text("Load Custom Model") + } + } else { + Text("Unknown status") + } + } + .onDisappear() { + downloadTask?.cancel() + } + .onChange(of: llamaState.cacheCleared) { newValue in + if newValue { + downloadTask?.cancel() + let fileURL = InputButton.getFileURL(filename: self.filename) + status = FileManager.default.fileExists(atPath: fileURL.path) ? 
"downloaded" : "download" + } + } + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/LoadCustomButton.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/LoadCustomButton.swift new file mode 100644 index 0000000..4315dbe --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/UI/LoadCustomButton.swift @@ -0,0 +1,44 @@ +import SwiftUI +import UniformTypeIdentifiers + +struct LoadCustomButton: View { + @ObservedObject private var llamaState: LlamaState + @State private var showFileImporter = false + + init(llamaState: LlamaState) { + self.llamaState = llamaState + } + + var body: some View { + VStack { + Button(action: { + showFileImporter = true + }) { + Text("Load Custom Model") + } + } + .fileImporter( + isPresented: $showFileImporter, + allowedContentTypes: [UTType(filenameExtension: "gguf", conformingTo: .data)!], + allowsMultipleSelection: false + ) { result in + switch result { + case .success(let files): + files.forEach { file in + let gotAccess = file.startAccessingSecurityScopedResource() + if !gotAccess { return } + + do { + try llamaState.loadModel(modelUrl: file.absoluteURL) + } catch let err { + print("Error: \(err.localizedDescription)") + } + + file.stopAccessingSecurityScopedResource() + } + case .failure(let error): + print(error) + } + } + } +} diff --git a/llama.cpp/examples/llama.swiftui/llama.swiftui/llama_swiftuiApp.swift b/llama.cpp/examples/llama.swiftui/llama.swiftui/llama_swiftuiApp.swift new file mode 100644 index 0000000..cccda8a --- /dev/null +++ b/llama.cpp/examples/llama.swiftui/llama.swiftui/llama_swiftuiApp.swift @@ -0,0 +1,10 @@ +import SwiftUI + +@main +struct llama_swiftuiApp: App { + var body: some Scene { + WindowGroup { + ContentView() + } + } +} diff --git a/llama.cpp/examples/llama.vim b/llama.cpp/examples/llama.vim new file mode 100644 index 0000000..736802d --- /dev/null +++ b/llama.cpp/examples/llama.vim @@ -0,0 +1,783 @@ +" LLM-based text completion using llama.cpp +" +" requires: +" +" - neovim or vim +" - curl +" - llama.cpp server instance +" - FIM-compatible model +" +" sample config: +" +" - Tab - accept the current suggestion +" - Shift+Tab - accept just the first line of the suggestion +" - Ctrl+F - toggle FIM completion manually +" +" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim +" +" start the llama.cpp server with a FIM-compatible model. for example: +" +" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 512 --batch-size 1024 --cache-reuse 256 +" +" --batch-size [512, model max context] +" +" adjust the batch size to control how much of the provided local context will be used during the inference +" lower values will use smaller part of the context around the cursor, which will result in faster processing +" +" --ubatch-size [64, 2048] +" +" chunks the batch into smaller chunks for faster processing +" depends on the specific hardware. use llama-bench to profile and determine the best size +" +" --cache-reuse (ge:llama_config.n_predict, 1024] +" +" this should be either 0 (disabled) or strictly larger than g:llama_config.n_predict +" using non-zero value enables context reuse on the server side which dramatically improves the performance at +" large contexts. 
+"
+" run this once to initialise llama.vim:
+"
+" :call llama#init()
+"
+" more info: https://github.com/ggml-org/llama.cpp/pull/9787
+"
+
+" colors (adjust to your liking)
+highlight llama_hl_hint guifg=#ff772f ctermfg=202
+highlight llama_hl_info guifg=#77ff2f ctermfg=119
+
+" general parameters:
+"
+" endpoint: llama.cpp server endpoint
+" n_prefix: number of lines before the cursor location to include in the local prefix
+" n_suffix: number of lines after the cursor location to include in the local suffix
+" n_predict: max number of tokens to predict
+" t_max_prompt_ms: max allotted time for the prompt processing (TODO: not yet supported)
+" t_max_predict_ms: max allotted time for the prediction
+" show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline)
+" auto_fim: trigger FIM completion automatically on cursor movement
+" max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor
+"
+" ring buffer of chunks, accumulated with time upon:
+"
+" - completion request
+" - yank
+" - entering a buffer
+" - leaving a buffer
+" - writing a file
+"
+" parameters for the ring-buffer with extra context:
+"
+" ring_n_chunks: max number of chunks to pass as extra context to the server (0 to disable)
+" ring_chunk_size: max size of the chunks (in number of lines)
+" note: adjust these numbers so that you don't overrun your context
+" at ring_n_chunks = 64 and ring_chunk_size = 64 you need ~32k context
+" ring_scope: the range around the cursor position (in number of lines) for gathering chunks after FIM
+" ring_update_ms: how often to process queued chunks in normal mode
+"
+let s:default_config = {
+    \ 'endpoint':         'http://127.0.0.1:8012/infill',
+    \ 'n_prefix':         256,
+    \ 'n_suffix':         64,
+    \ 'n_predict':        128,
+    \ 't_max_prompt_ms':  500,
+    \ 't_max_predict_ms': 3000,
+    \ 'show_info':        2,
+    \ 'auto_fim':         v:true,
+    \ 'max_line_suffix':  8,
+    \ 'ring_n_chunks':    64,
+    \ 'ring_chunk_size':  64,
+    \ 'ring_scope':       1024,
+    \ 'ring_update_ms':   1000,
+    \ }
+
+let g:llama_config = get(g:, 'llama_config', s:default_config)
+
+function! s:get_indent(str)
+    let l:count = 0
+    for i in range(len(a:str))
+        if a:str[i] == "\t"
+            let l:count += &tabstop - 1
+        else
+            break
+        endif
+    endfor
+    return l:count
+endfunction
+
+function! s:rand(i0, i1) abort
+    return a:i0 + rand() % (a:i1 - a:i0 + 1)
+endfunction
+
+function!
llama#init() + if !executable('curl') + echohl WarningMsg + echo 'llama.vim requires the "curl" command to be available' + echohl None + return + endif + + let s:pos_x = 0 " cursor position upon start of completion + let s:pos_y = 0 + + let s:line_cur = '' + + let s:line_cur_prefix = '' + let s:line_cur_suffix = '' + + let s:ring_chunks = [] " current set of chunks used as extra context + let s:ring_queued = [] " chunks that are queued to be sent for processing + let s:ring_n_evict = 0 + + let s:hint_shown = v:false + let s:pos_y_pick = -9999 " last y where we picked a chunk + let s:pos_dx = 0 + let s:content = [] + let s:can_accept = v:false + + let s:timer_fim = -1 + let s:t_fim_start = reltime() " used to measure total FIM time + let s:t_last_move = reltime() " last time the cursor moved + + let s:current_job = v:null + + let s:ghost_text_nvim = exists('*nvim_buf_get_mark') + let s:ghost_text_vim = has('textprop') + + if s:ghost_text_vim + let s:hlgroup_hint = 'llama_hl_hint' + let s:hlgroup_info = 'llama_hl_info' + + if empty(prop_type_get(s:hlgroup_hint)) + call prop_type_add(s:hlgroup_hint, {'highlight': s:hlgroup_hint}) + endif + if empty(prop_type_get(s:hlgroup_info)) + call prop_type_add(s:hlgroup_info, {'highlight': s:hlgroup_info}) + endif + endif + + augroup llama + autocmd! + autocmd InsertEnter * inoremap <expr> <silent> <C-F> llama#fim_inline(v:false) + autocmd InsertLeavePre * call llama#fim_cancel() + + autocmd CursorMoved * call s:on_move() + autocmd CursorMovedI * call s:on_move() + autocmd CompleteChanged * call llama#fim_cancel() + + if g:llama_config.auto_fim + autocmd CursorMovedI * call llama#fim(v:true) + endif + + " gather chunks upon yanking + autocmd TextYankPost * if v:event.operator ==# 'y' | call s:pick_chunk(v:event.regcontents, v:false, v:true) | endif + + " gather chunks upon entering/leaving a buffer + autocmd BufEnter * call timer_start(100, {-> s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true)}) + autocmd BufLeave * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + + " gather chunk upon saving the file + autocmd BufWritePost * call s:pick_chunk(getline(max([1, line('.') - g:llama_config.ring_chunk_size/2]), min([line('.') + g:llama_config.ring_chunk_size/2, line('$')])), v:true, v:true) + augroup END + + silent! call llama#fim_cancel() + + " init background update of the ring buffer + if g:llama_config.ring_n_chunks > 0 + call s:ring_update() + endif +endfunction + +" compute how similar two chunks of text are +" 0 - no similarity, 1 - high similarity +" TODO: figure out something better +function! s:chunk_sim(c0, c1) + let l:lines0 = len(a:c0) + let l:lines1 = len(a:c1) + + let l:common = 0 + + for l:line0 in a:c0 + for l:line1 in a:c1 + if l:line0 == l:line1 + let l:common += 1 + break + endif + endfor + endfor + + return 2.0 * l:common / (l:lines0 + l:lines1) +endfunction + +" pick a random chunk of size g:llama_config.ring_chunk_size from the provided text and queue it for processing +" +" no_mod - do not pick chunks from buffers with pending changes +" do_evict - evict chunks that are very similar to the new one +" +function! 
s:pick_chunk(text, no_mod, do_evict) + " do not pick chunks from buffers with pending changes or buffers that are not files + if a:no_mod && (getbufvar(bufnr('%'), '&modified') || !buflisted(bufnr('%')) || !filereadable(expand('%'))) + return + endif + + " if the extra context option is disabled - do nothing + if g:llama_config.ring_n_chunks <= 0 + return + endif + + " don't pick very small chunks + if len(a:text) < 3 + return + endif + + if len(a:text) + 1 < g:llama_config.ring_chunk_size + let l:chunk = a:text + else + let l:l0 = s:rand(0, max([0, len(a:text) - g:llama_config.ring_chunk_size/2])) + let l:l1 = min([l:l0 + g:llama_config.ring_chunk_size/2, len(a:text)]) + + let l:chunk = a:text[l:l0:l:l1] + endif + + let l:chunk_str = join(l:chunk, "\n") . "\n" + + " check if this chunk is already added + let l:exist = v:false + + for i in range(len(s:ring_chunks)) + if s:ring_chunks[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + for i in range(len(s:ring_queued)) + if s:ring_queued[i].data == l:chunk + let l:exist = v:true + break + endif + endfor + + if l:exist + return + endif + + " evict queued chunks that are very similar to the new one + for i in range(len(s:ring_queued) - 1, 0, -1) + if s:chunk_sim(s:ring_queued[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_queued, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " also from s:ring_chunks + for i in range(len(s:ring_chunks) - 1, 0, -1) + if s:chunk_sim(s:ring_chunks[i].data, l:chunk) > 0.9 + if a:do_evict + call remove(s:ring_chunks, i) + let s:ring_n_evict += 1 + else + return + endif + endif + endfor + + " TODO: become parameter ? + if len(s:ring_queued) == 16 + call remove(s:ring_queued, 0) + endif + + call add(s:ring_queued, {'data': l:chunk, 'str': l:chunk_str, 'time': reltime(), 'filename': expand('%')}) + + "let &statusline = 'extra context: ' . len(s:ring_chunks) . ' / ' . len(s:ring_queued) +endfunction + +" picks a queued chunk, sends it for processing and adds it to s:ring_chunks +" called every g:llama_config.ring_update_ms +function! s:ring_update() + call timer_start(g:llama_config.ring_update_ms, {-> s:ring_update()}) + + " update only if in normal mode or if the cursor hasn't moved for a while + if mode() !=# 'n' && reltimefloat(reltime(s:t_last_move)) < 3.0 + return + endif + + if len(s:ring_queued) == 0 + return + endif + + " move the first queued chunk to the ring buffer + if len(s:ring_chunks) == g:llama_config.ring_n_chunks + call remove(s:ring_chunks, 0) + endif + + call add(s:ring_chunks, remove(s:ring_queued, 0)) + + "let &statusline = 'updated context: ' . len(s:ring_chunks) . ' / ' . 
len(s:ring_queued) + + " send asynchronous job with the new extra context so that it is ready for the next FIM + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " no samplers needed here + let l:request = json_encode({ + \ 'input_prefix': "", + \ 'input_suffix': "", + \ 'input_extra': l:extra_context, + \ 'prompt': "", + \ 'n_predict': 1, + \ 'temperature': 0.0, + \ 'stream': v:false, + \ 'samplers': ["temperature"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': 1, + \ 't_max_predict_ms': 1 + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + " no callbacks because we don't need to process the response + if s:ghost_text_nvim + call jobstart(l:curl_command, {}) + elseif s:ghost_text_vim + call job_start(l:curl_command, {}) + endif +endfunction + +" necessary for 'inoremap <expr>' +function! llama#fim_inline(is_auto) abort + call llama#fim(a:is_auto) + return '' +endfunction + +" the main FIM call +" takes local context around the cursor and sends it together with the extra context to the server for completion +function! llama#fim(is_auto) abort + " we already have a suggestion for the current cursor position + if s:hint_shown && !a:is_auto + call llama#fim_cancel() + return + endif + + call llama#fim_cancel() + + " avoid sending repeated requests too fast + if reltimefloat(reltime(s:t_fim_start)) < 0.6 + if s:timer_fim != -1 + call timer_stop(s:timer_fim) + let s:timer_fim = -1 + endif + + let s:t_fim_start = reltime() + let s:timer_fim = timer_start(600, {-> llama#fim(v:true)}) + return + endif + + let s:t_fim_start = reltime() + + let s:content = [] + let s:can_accept = v:false + + let s:pos_x = col('.') - 1 + let s:pos_y = line('.') + let l:max_y = line('$') + + let l:lines_prefix = getline(max([1, s:pos_y - g:llama_config.n_prefix]), s:pos_y - 1) + let l:lines_suffix = getline(s:pos_y + 1, min([l:max_y, s:pos_y + g:llama_config.n_suffix])) + + let s:line_cur = getline('.') + + let s:line_cur_prefix = strpart(s:line_cur, 0, s:pos_x) + let s:line_cur_suffix = strpart(s:line_cur, s:pos_x) + + if a:is_auto && len(s:line_cur_suffix) > g:llama_config.max_line_suffix + return + endif + + let l:prefix = "" + \ . join(l:lines_prefix, "\n") + \ . "\n" + + let l:prompt = "" + \ . s:line_cur_prefix + + let l:suffix = "" + \ . s:line_cur_suffix + \ . "\n" + \ . join(l:lines_suffix, "\n") + \ . 
"\n" + + " prepare the extra context data + let l:extra_context = [] + for l:chunk in s:ring_chunks + call add(l:extra_context, { + \ 'text': l:chunk.str, + \ 'time': l:chunk.time, + \ 'filename': l:chunk.filename + \ }) + endfor + + " the indentation of the current line + let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + + let l:request = json_encode({ + \ 'input_prefix': l:prefix, + \ 'input_suffix': l:suffix, + \ 'input_extra': l:extra_context, + \ 'prompt': l:prompt, + \ 'n_predict': g:llama_config.n_predict, + \ 'n_indent': l:indent, + \ 'top_k': 40, + \ 'top_p': 0.99, + \ 'stream': v:false, + \ 'samplers': ["top_k", "top_p", "infill"], + \ 'cache_prompt': v:true, + \ 't_max_prompt_ms': g:llama_config.t_max_prompt_ms, + \ 't_max_predict_ms': g:llama_config.t_max_predict_ms + \ }) + + let l:curl_command = [ + \ "curl", + \ "--silent", + \ "--no-buffer", + \ "--request", "POST", + \ "--url", g:llama_config.endpoint, + \ "--header", "Content-Type: application/json", + \ "--data", l:request + \ ] + + if s:current_job != v:null + if s:ghost_text_nvim + call jobstop(s:current_job) + elseif s:ghost_text_vim + call job_stop(s:current_job) + endif + endif + + " send the request asynchronously + if s:ghost_text_nvim + let s:current_job = jobstart(l:curl_command, { + \ 'on_stdout': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'on_exit': function('s:fim_on_exit'), + \ 'stdout_buffered': v:true + \ }) + elseif s:ghost_text_vim + let s:current_job = job_start(l:curl_command, { + \ 'out_cb': function('s:fim_on_stdout', [s:pos_x, s:pos_y, a:is_auto]), + \ 'exit_cb': function('s:fim_on_exit') + \ }) + endif + + " TODO: per-file location + let l:delta_y = abs(s:pos_y - s:pos_y_pick) + + " gather some extra context nearby and process it in the background + " only gather chunks if the cursor has moved a lot + " TODO: something more clever? reranking? + if a:is_auto && l:delta_y > 32 + " expand the prefix even further + call s:pick_chunk(getline(max([1, s:pos_y - g:llama_config.ring_scope]), max([1, s:pos_y - g:llama_config.n_prefix])), v:false, v:false) + + " pick a suffix chunk + call s:pick_chunk(getline(min([l:max_y, s:pos_y + g:llama_config.n_suffix]), min([l:max_y, s:pos_y + g:llama_config.n_suffix + g:llama_config.ring_chunk_size])), v:false, v:false) + + let s:pos_y_pick = s:pos_y + endif +endfunction + +" if first_line == v:true accept only the first line of the response +function! llama#fim_accept(first_line) + " insert the suggestion at the cursor location + if s:can_accept && len(s:content) > 0 + call setline(s:pos_y, s:line_cur[:(s:pos_x - 1)] . s:content[0]) + if len(s:content) > 1 + if !a:first_line + call append(s:pos_y, s:content[1:-1]) + endif + endif + + " move the cursor to the end of the accepted text + if !a:first_line && len(s:content) > 1 + call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx + 1) + else + call cursor(s:pos_y, s:pos_x + len(s:content[0])) + endif + endif + + call llama#fim_cancel() +endfunction + +function! llama#fim_cancel() + let s:hint_shown = v:false + + " clear the virtual text + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1) + elseif s:ghost_text_vim + call prop_remove({'type': s:hlgroup_hint, 'all': v:true}) + call prop_remove({'type': s:hlgroup_info, 'all': v:true}) + endif + + " remove the mappings + silent! iunmap <buffer> <Tab> + silent! iunmap <buffer> <S-Tab> + silent! 
iunmap <buffer> <Esc> +endfunction + +function! s:on_move() + let s:t_last_move = reltime() + + call llama#fim_cancel() +endfunction + +" callback that processes the FIM result from the server and displays the suggestion +function! s:fim_on_stdout(pos_x, pos_y, is_auto, job_id, data, event = v:null) + if s:ghost_text_nvim + let l:raw = join(a:data, "\n") + elseif s:ghost_text_vim + let l:raw = a:data + endif + + if len(l:raw) == 0 + return + endif + + if a:pos_x != col('.') - 1 || a:pos_y != line('.') + return + endif + + " show the suggestion only in insert mode + if mode() !=# 'i' + return + endif + + let s:pos_x = a:pos_x + let s:pos_y = a:pos_y + + let s:can_accept = v:true + let l:has_info = v:false + + if s:can_accept && v:shell_error + if !a:is_auto + call add(s:content, "<| curl error: is the server on? |>") + endif + let s:can_accept = v:false + endif + + let l:n_prompt = 0 + let l:t_prompt_ms = 1.0 + let l:s_prompt = 0 + + let l:n_predict = 0 + let l:t_predict_ms = 1.0 + let l:s_predict = 0 + + " get the generated suggestion + if s:can_accept + let l:response = json_decode(l:raw) + + for l:part in split(get(l:response, 'content', ''), "\n", 1) + call add(s:content, l:part) + endfor + + " remove trailing new lines + while len(s:content) > 0 && s:content[-1] == "" + call remove(s:content, -1) + endwhile + + let l:generation_settings = get(l:response, 'generation_settings', {}) + let l:n_ctx = get(l:generation_settings, 'n_ctx', 0) + + let l:n_cached = get(l:response, 'tokens_cached', 0) + let l:truncated = get(l:response, 'truncated', v:false) + + " if response.timings is available + if len(get(l:response, 'timings', {})) > 0 + let l:has_info = v:true + let l:timings = get(l:response, 'timings', {}) + + let l:n_prompt = get(l:timings, 'prompt_n', 0) + let l:t_prompt_ms = get(l:timings, 'prompt_ms', 1) + let l:s_prompt = get(l:timings, 'prompt_per_second', 0) + + let l:n_predict = get(l:timings, 'predicted_n', 0) + let l:t_predict_ms = get(l:timings, 'predicted_ms', 1) + let l:s_predict = get(l:timings, 'predicted_per_second', 0) + endif + endif + + if len(s:content) == 0 + call add(s:content, "") + let s:can_accept = v:false + endif + + if len(s:content) == 0 + return + endif + + " NOTE: the following is logic for discarding predictions that repeat existing text + " the code is quite ugly and there is very likely a simpler and more canonical way to implement this + " + " still, I wonder if there is some better way that avoids having to do these special hacks? + " on one hand, the LLM 'sees' the contents of the file before we start editing, so it is normal that it would + " start generating whatever we have given it via the extra context. but on the other hand, it's not very + " helpful to re-generate the same code that is already there + + " truncate the suggestion if the first line is empty + if len(s:content) == 1 && s:content[0] == "" + let s:content = [""] + endif + + " ... and the next lines are repeated + if len(s:content) > 1 && s:content[0] == "" && s:content[1:] == getline(s:pos_y + 1, s:pos_y + len(s:content) - 1) + let s:content = [""] + endif + + " truncate the suggestion if it repeats the suffix + if len(s:content) == 1 && s:content[0] == s:line_cur_suffix + let s:content = [""] + endif + + " find the first non-empty line (strip whitespace) + let l:cmp_y = s:pos_y + 1 + while l:cmp_y < line('$') && getline(l:cmp_y) =~? '^\s*$' + let l:cmp_y += 1 + endwhile + + if (s:line_cur_prefix . 
s:content[0]) == getline(l:cmp_y) + " truncate the suggestion if it repeats the next line + if len(s:content) == 1 + let s:content = [""] + endif + + " ... or if the second line of the suggestion is the prefix of line l:cmp_y + 1 + if len(s:content) == 2 && s:content[-1] == getline(l:cmp_y + 1)[:len(s:content[-1]) - 1] + let s:content = [""] + endif + + " ... or if the middle chunk of lines of the suggestion is the same as [l:cmp_y + 1, l:cmp_y + len(s:content) - 1) + if len(s:content) > 2 && join(s:content[1:-1], "\n") == join(getline(l:cmp_y + 1, l:cmp_y + len(s:content) - 1), "\n") + let s:content = [""] + endif + endif + + " keep only lines that have the same or larger whitespace prefix as s:line_cur_prefix + "let l:indent = strlen(matchstr(s:line_cur_prefix, '^\s*')) + "for i in range(1, len(s:content) - 1) + " if strlen(matchstr(s:content[i], '^\s*')) < l:indent + " let s:content = s:content[:i - 1] + " break + " endif + "endfor + + let s:pos_dx = len(s:content[-1]) + + let s:content[-1] .= s:line_cur_suffix + + call llama#fim_cancel() + + " display virtual text with the suggestion + let l:bufnr = bufnr('%') + + if s:ghost_text_nvim + let l:id_vt_fim = nvim_create_namespace('vt_fim') + endif + + " construct the info message + if g:llama_config.show_info > 0 && l:has_info + let l:prefix = ' ' + + if l:truncated + let l:info = printf("%s | WARNING: the context is full: %d / %d, increase the server context size or reduce g:llama_config.ring_n_chunks", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx + \ ) + else + let l:info = printf("%s | c: %d / %d, r: %d / %d, e: %d, q: %d / 16 | p: %d (%.2f ms, %.2f t/s) | g: %d (%.2f ms, %.2f t/s) | t: %.2f ms", + \ g:llama_config.show_info == 2 ? l:prefix : 'llama.vim', + \ l:n_cached, l:n_ctx, len(s:ring_chunks), g:llama_config.ring_n_chunks, s:ring_n_evict, len(s:ring_queued), + \ l:n_prompt, l:t_prompt_ms, l:s_prompt, + \ l:n_predict, l:t_predict_ms, l:s_predict, + \ 1000.0 * reltimefloat(reltime(s:t_fim_start)) + \ ) + endif + + if g:llama_config.show_info == 1 + " display the info in the statusline + let &statusline = l:info + let l:info = '' + endif + endif + + " display the suggestion and append the info to the end of the first line + if s:ghost_text_nvim + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, { + \ 'virt_text': [[s:content[0], 'llama_hl_hint'], [l:info, 'llama_hl_info']], + \ 'virt_text_win_col': virtcol('.') - 1 + \ }) + + call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, 0, { + \ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hl_hint']]}), + \ 'virt_text_win_col': virtcol('.') + \ }) + elseif s:ghost_text_vim + let l:new_suffix = s:content[0] + if !empty(l:new_suffix) + call prop_add(s:pos_y, s:pos_x + 1, { + \ 'type': s:hlgroup_hint, + \ 'text': l:new_suffix + \ }) + endif + for line in s:content[1:] + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_hint, + \ 'text': line, + \ 'text_padding_left': s:get_indent(line), + \ 'text_align': 'below' + \ }) + endfor + if !empty(l:info) + call prop_add(s:pos_y, 0, { + \ 'type': s:hlgroup_info, + \ 'text': l:info, + \ 'text_padding_left': col('$'), + \ 'text_wrap': 'truncate' + \ }) + endif + endif + + " setup accept shortcuts + inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR> + inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR> + + let s:hint_shown = v:true +endfunction + +function! 
s:fim_on_exit(job_id, exit_code, event = v:null) + if a:exit_code != 0 + echom "Job failed with exit code: " . a:exit_code + endif + + let s:current_job = v:null +endfunction diff --git a/llama.cpp/examples/lookahead/CMakeLists.txt b/llama.cpp/examples/lookahead/CMakeLists.txt new file mode 100644 index 0000000..3468613 --- /dev/null +++ b/llama.cpp/examples/lookahead/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-lookahead) +add_executable(${TARGET} lookahead.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/lookahead/README.md b/llama.cpp/examples/lookahead/README.md new file mode 100644 index 0000000..c82de2a --- /dev/null +++ b/llama.cpp/examples/lookahead/README.md @@ -0,0 +1,13 @@ +# llama.cpp/examples/lookahead + +Demonstration of lookahead decoding technique: + +https://lmsys.org/blog/2023-11-21-lookahead-decoding/ + +More info: https://github.com/ggml-org/llama.cpp/pull/4207 + +Sample command: + +```bash +llama-lookahead -hf ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF -p "// network server implemented in C\n// author: Peter Hacker\n\n#include" -e -ngl 99 -t 4 -n 512 -c 4096 -kvu +``` diff --git a/llama.cpp/examples/lookahead/lookahead.cpp b/llama.cpp/examples/lookahead/lookahead.cpp new file mode 100644 index 0000000..aa6efa6 --- /dev/null +++ b/llama.cpp/examples/lookahead/lookahead.cpp @@ -0,0 +1,480 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + +#include <cstdio> +#include <string> +#include <vector> +#include <algorithm> + +struct ngram_data { + bool active = false; + + llama_seq_id seq_id = -1; + + std::vector<int> i_batch; + + std::vector<llama_token> tokens; +}; + +// n-gram container +struct ngram_container { + ngram_container(int n_vocab, int N, int G) { + cnt.resize(n_vocab); + head.resize(n_vocab); + tokens.resize(n_vocab * G * (N - 1)); + } + + int n_total = 0; + + std::vector<int> cnt; + std::vector<int> head; + + // [n_vocab][G][N - 1] + // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1 + std::vector<llama_token> tokens; +}; + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + const int W = 15; // lookahead window + const int N = 5; // n-gram size + const int G = 15; // max verification n-grams + + // lookahead requires W + G + 1 sequences for parallel Jacobi decoding + params.n_parallel = W + G + 1; + + // unified KV cache is required for coupled sequences in batch splitting + params.kv_unified = true; + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + // load the target model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + auto * mem = llama_get_memory(ctx); + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // Tokenize the prompt + std::vector<llama_token> inp; + std::vector<llama_token> all; + + inp = common_tokenize(ctx, params.prompt, true, true); + all = inp; + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + LOG("\n\n"); + + for 
(auto id : inp) { + LOG("%s", common_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + // eval the prompt + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + + for (int s = 1; s < W + G + 1; ++s) { + llama_memory_seq_cp(mem, 0, s, -1, -1); + } + + const auto t_enc_end = ggml_time_us(); + + int n_predict = 0; + int n_accept = 0; + + int n_past = inp.size(); + + llama_token id = 0; + + // used to determine end of generation + bool has_eos = false; + + // for each decoded batch, we have at most W + G + 1 distinct sequences: + // seq_id == 0 : the current input token + // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations + // seq_id [W + 1, W + G] : verification n-grams + llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1); + + // target model sampling context + struct common_sampler * smpl = common_sampler_init(model, params.sampling); + + // verification n-grams + std::vector<ngram_data> ngrams_cur(G); + + // tokens for the past N - 1 Jacobi iterations + std::vector<llama_token> tokens_j_prev(W); + std::vector<std::vector<llama_token>> tokens_j(N - 1); + for (int j = 0; j < N - 1; j++) { + tokens_j[j].resize(W); + + for (int i = 0; i < W; i++) { + // there are different ways to init these tokens + if (0) { + // initialize randomly from the prompt tokens + tokens_j[j][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // initialize with a sequence of increasing numbers + tokens_j[j][i] = 100 + i; + } + } + } + + std::vector<llama_seq_id> seq_id_look; + + // the input token belongs both to all sequences + std::vector<llama_seq_id> seq_id_all(W + G + 1); + for (int i = 0; i < W + G + 1; i++) { + seq_id_all[i] = i; + } + + // here we keep adding new n-grams as we go + ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G); + + const auto t_dec_start = ggml_time_us(); + + // sample first token + { + id = common_sampler_sample(smpl, ctx, 0); + + common_sampler_accept(smpl, id, true); + + { + const std::string token_str = common_token_to_piece(ctx, id); + + LOG("%s", token_str.c_str()); + fflush(stdout); + } + } + + while (true) { + // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/ + // + // Example for W = 5, N = 4, G = 2: + // (I = input, L = lookahead, V = verification) + // + // Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 + // T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0 + // Info: I L L L L L L L L L L L L L L V V V V V V + // Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past) + // Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 + // --------------------------------------------------------------------- + // Seq: 0 + // 1 1 1 + // 2 2 2 2 + // 3 3 3 3 3 + // 4 4 4 4 4 4 + // 5 5 5 5 5 5 5 + // 6 6 6 6 + // 7 7 7 7 + // --------------------------------------------------------------------- + // | | | | | | | | | | | + // V V V V V | | | | | | + // j_tokens | | | | | | + // V V V V V V + // id + { + common_batch_clear(batch); + + // current token - first token of the first level + common_batch_add(batch, id, n_past, seq_id_all, true); + + // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation + { + const int g_cur = ngrams_observed.cnt[id]; + + ngrams_cur.resize(g_cur); + for (int g = 0; g < g_cur; g++) { + ngrams_cur[g].active = true; + ngrams_cur[g].tokens.resize(N); + 
ngrams_cur[g].i_batch.resize(N); + ngrams_cur[g].seq_id = W + 1 + g; + ngrams_cur[g].i_batch[0] = 0; + ngrams_cur[g].tokens [0] = id; + } + + for (int j = 0; j < N - 1; j++) { + for (int g = 0; g < g_cur; g++) { + const int idx = id*(N - 1)*G + g*(N - 1); + + const llama_token t = ngrams_observed.tokens[idx + j]; + + ngrams_cur[g].tokens [j + 1] = t; + ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; + + common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + } + } + } + + // fill the remaining W - 1 tokens for the first level + for (int i = 1; i < W; i++) { + seq_id_look.resize(W - i); + for (int j = 0; j < W - i; j++) { + seq_id_look[j] = i + j + 1; + } + + common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + } + + // fill the rest of the levels + for (int j = 1; j < N - 1; j++) { + for (int i = 0; i < W; i++) { + common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + } + } + } + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); + return 1; + } + + int seq_id_best = 0; + + for (int v = 0; v < N; ++v) { + int i_batch = 0; + + // if no active ngrams are left, it means the sampled token does not pass the verification + if (v > 0) { + for (int g = 0; g < (int) ngrams_cur.size(); g++) { + if (ngrams_cur[g].active) { + i_batch = ngrams_cur[g].i_batch[v]; + seq_id_best = ngrams_cur[g].seq_id; + + ++n_accept; + break; + } + } + + // no more matches -> create a new batch + if (i_batch == 0) { + break; + } + } + + // sample the next token + id = common_sampler_sample(smpl, ctx, i_batch); + + common_sampler_accept(smpl, id, true); + + // print + { + const std::string token_str = common_token_to_piece(ctx, id); + + if (v == 0) { + LOG("%s", token_str.c_str()); + } else { + // print light cyan + LOG("\033[0;96m%s\033[0m", token_str.c_str()); + } + fflush(stdout); + + if (llama_vocab_is_eog(vocab, id)) { + has_eos = true; + } + + all.push_back(id); + } + + ++n_predict; + ++n_past; + + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { + break; + } + + // verify across active n-grams + for (int g = 0; g < (int) ngrams_cur.size(); g++) { + if (ngrams_cur[g].active) { + if (v == N - 1) { + ngrams_cur[g].active = false; + } else { + if (id != ngrams_cur[g].tokens[v + 1]) { + ngrams_cur[g].active = false; + } + } + } + } + + // print known n-grams starting with token id (debug) + if (0 && v == 0) { + if (ngrams_observed.cnt[id] > 0) { + LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str()); + } + + for (int i = 0; i < ngrams_observed.cnt[id]; i++) { + LOG(" - ngram %2d: ", i); + + const int idx = id*(N - 1)*G + i*(N - 1); + + for (int j = 0; j < N - 1; j++) { + const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]); + + LOG("%s", token_str.c_str()); + } + + LOG("\n"); + } + } + + // update lookahead tokens + { + for (int i = 0; i < W; i++) { + tokens_j_prev[i] = tokens_j[0][i]; + } + + for (int j = 0; j < N - 2; j++) { + tokens_j[j] = tokens_j[j + 1]; + } + + if (v == 0) { + // sample from the last level + for (int i = 0; i < W; i++) { + tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + } + } else { + for (int i = 0; i < W; i++) { + // there are different ways to init these tokens + if (0) { + // random init + tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)]; + } else { + // init from the previous level + tokens_j[N - 
2][i] = tokens_j[0][i]; + } + } + } + } + + // update observed ngrams + if (v == 0) { + // the first token of the n-gram is determined by the index in the container so it is not stored + std::vector<llama_token> ngram(N - 1); + + // n-gram generation + // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518 + for (int f = 0; f < W; ++f) { + const int ft = tokens_j_prev[f]; // first token of the n-gram + + for (int j = 0; j < N - 1; ++j) { + ngram[j] = tokens_j[j][f]; + } + + // filter-out repeating n-grams + { + bool is_unique = true; + + for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) { + const int idx = ft*(N - 1)*G + k*(N - 1); + + bool is_match = true; + for (int j = 0; j < N - 1; ++j) { + if (ngrams_observed.tokens[idx + j] != ngram[j]) { + is_match = false; + break; + } + } + + if (is_match) { + is_unique = false; + break; + } + } + + if (!is_unique) { + continue; + } + } + + const int head = ngrams_observed.head[ft]; + const int idx = ft*(N - 1)*G + head*(N - 1); + + for (int i = 0; i < N - 1; i++) { + ngrams_observed.tokens[idx + i] = ngram[i]; + } + + ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1); + ngrams_observed.head[ft] = (head + 1) % G; + + ngrams_observed.n_total++; + } + } + } + + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { + break; + } + + // KV cache management + // if no verification token matched, we simply remove all cells from this batch -> no fragmentation + llama_memory_seq_rm(mem, -1, n_past, -1); + + if (seq_id_best != 0) { + // if a verification token matched, we keep the best sequence and remove the rest + // this leads to some KV cache fragmentation + llama_memory_seq_keep(mem, seq_id_best); + llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1); + llama_memory_seq_rm (mem, seq_id_best, -1, -1); + + for (int s = 1; s < W + G + 1; ++s) { + llama_memory_seq_cp(mem, 0, s, -1, -1); + } + } + } + + auto t_dec_end = ggml_time_us(); + + LOG("\n\n"); + + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_INF("\n"); + LOG_INF("W = %2d\n", W); + LOG_INF("N = %2d\n", N); + LOG_INF("G = %2d\n", G); + LOG_INF("\n"); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_accept = %d\n", n_accept); + + LOG_INF("\n"); + common_perf_print(ctx, smpl); + + common_sampler_free(smpl); + + llama_batch_free(batch); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git a/llama.cpp/examples/lookup/CMakeLists.txt b/llama.cpp/examples/lookup/CMakeLists.txt new file mode 100644 index 0000000..fba78ce --- /dev/null +++ b/llama.cpp/examples/lookup/CMakeLists.txt @@ -0,0 +1,23 @@ +set(TARGET llama-lookup) +add_executable(${TARGET} lookup.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-lookup-create) +add_executable(${TARGET} lookup-create.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-lookup-merge) +add_executable(${TARGET} lookup-merge.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common 
llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-lookup-stats) +add_executable(${TARGET} lookup-stats.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/lookup/README.md b/llama.cpp/examples/lookup/README.md new file mode 100644 index 0000000..07d7384 --- /dev/null +++ b/llama.cpp/examples/lookup/README.md @@ -0,0 +1,12 @@ +# llama.cpp/examples/lookup + +Demonstration of Prompt Lookup Decoding + +https://github.com/apoorvumang/prompt-lookup-decoding + +The key parameters for lookup decoding are `ngram_min`, `ngram_max` and `n_draft`. The first two determine the size of the ngrams to search for in the prompt for a match. `n_draft` specifies how many subsequent tokens to draft if a match is found. + +More info: + +https://github.com/ggml-org/llama.cpp/pull/4484 +https://github.com/ggml-org/llama.cpp/issues/4226 diff --git a/llama.cpp/examples/lookup/lookup-create.cpp b/llama.cpp/examples/lookup/lookup-create.cpp new file mode 100644 index 0000000..f7b6ea1 --- /dev/null +++ b/llama.cpp/examples/lookup/lookup-create.cpp @@ -0,0 +1,40 @@ +#include "arg.h" +#include "common.h" +#include "ngram-cache.h" +#include "llama.h" + +#include <string> +#include <vector> + +int main(int argc, char ** argv){ + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { + return 1; + } + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + GGML_ASSERT(model != nullptr); + + // tokenize the prompt + std::vector<llama_token> inp; + inp = common_tokenize(ctx, params.prompt, true, true); + fprintf(stderr, "%s: tokenization done\n", __func__); + + common_ngram_cache ngram_cache; + common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); + fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str()); + + common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static); + + return 0; +} diff --git a/llama.cpp/examples/lookup/lookup-merge.cpp b/llama.cpp/examples/lookup/lookup-merge.cpp new file mode 100644 index 0000000..6871c0f --- /dev/null +++ b/llama.cpp/examples/lookup/lookup-merge.cpp @@ -0,0 +1,47 @@ +#include "ggml.h" +#include "llama.h" +#include "common.h" +#include "ngram-cache.h" + +#include <cstdint> +#include <cstdio> +#include <fstream> +#include <iostream> +#include <string> +#include <unordered_map> +#include <vector> + +static void print_usage(char* argv0) { + fprintf(stderr, "Merges multiple lookup cache files into a single one.\n"); + fprintf(stderr, "Usage: %s [--help] lookup_part_1.bin lookup_part_2.bin ... 
lookup_merged.bin\n", argv0); +} + +int main(int argc, char ** argv){ + if (argc < 3) { + print_usage(argv[0]); + exit(1); + } + + std::vector<std::string> args; + args.resize(argc-1); + for (int i = 0; i < argc-1; ++i) { + args[i] = argv[i+1]; + if (args[i] == "-h" || args[i] == "--help") { + print_usage(argv[0]); + exit(0); + } + } + + fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str()); + common_ngram_cache ngram_cache_merged = common_ngram_cache_load(args[0]); + + for (size_t i = 1; i < args.size()-1; ++i) { + fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str()); + common_ngram_cache ngram_cache = common_ngram_cache_load(args[i]); + + common_ngram_cache_merge(ngram_cache_merged, ngram_cache); + } + + fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str()); + common_ngram_cache_save(ngram_cache_merged, args.back()); +} diff --git a/llama.cpp/examples/lookup/lookup-stats.cpp b/llama.cpp/examples/lookup/lookup-stats.cpp new file mode 100644 index 0000000..ae28b2e --- /dev/null +++ b/llama.cpp/examples/lookup/lookup-stats.cpp @@ -0,0 +1,157 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "ngram-cache.h" +#include "llama.h" +#include "ggml.h" + +#include <cstdint> +#include <cstdio> +#include <cinttypes> +#include <fstream> +#include <string> +#include <vector> + +int main(int argc, char ** argv){ + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { + return 1; + } + + common_init(); + + const int n_draft = params.speculative.n_max; + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + auto llama_init = common_init_from_params(params); + + llama_context * ctx = llama_init->context(); + + // tokenize the prompt + std::vector<llama_token> inp; + inp = common_tokenize(ctx, params.prompt, true, true); + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + int64_t t_draft_flat_us = 0; + int64_t t_draft_us = 0; + + { + const int64_t t_start_draft_us = ggml_time_us(); + + if (!params.speculative.lookup_cache_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); + } catch (std::ifstream::failure const &) { + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); + exit(1); + } + } + + if (!params.speculative.lookup_cache_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); + } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program + } + + t_draft_flat_us += ggml_time_us() - t_start_draft_us; + } + + const int n_input = inp.size(); + const int n_ctx = llama_n_ctx(ctx); + + int n_drafted = 0; + int n_accept = 0; + + const int64_t t_start_ms = ggml_time_ms(); + + // Iterate over input tokens in chunks of size n_ctx. + // Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility. 
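+ // In other words: at every position we ask the n-gram caches for a draft and + // count how many of the drafted tokens match the tokens that actually follow + // in the prompt. This estimates the acceptance rate of lookup decoding on the + // given input without ever running the model.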
+ for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) { + const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx); + std::vector<llama_token> pseudo_output; + pseudo_output.push_back(inp_slice[0]); + + while ((int) pseudo_output.size() < n_ctx) { + // Simulate drafting and decoding from draft: + std::vector<llama_token> draft; + draft.push_back(pseudo_output.back()); + + { + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + t_draft_us += ggml_time_us() - t_start_draft_us; + } + + n_drafted += draft.size() - 1; + + for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) { + const llama_token ground_truth = inp_slice[pseudo_output.size()]; + const llama_token drafted = draft[j]; + + if (ground_truth != drafted) { + break; + } + + ++n_accept; + pseudo_output.push_back(ground_truth); + + { + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + t_draft_us += ggml_time_us() - t_start_draft_us; + } + } + + // After each simulated batch decoding simulate the sampling of a single token: + if ((int) pseudo_output.size() < n_ctx) { + pseudo_output.push_back(inp_slice[pseudo_output.size()]); + { + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false); + t_draft_us += ggml_time_us() - t_start_draft_us; + } + } + + draft.erase(draft.begin()); + + } + if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) { + const int64_t t_now_ms = ggml_time_ms(); + const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start; + const int64_t eta_min = eta_ms / (60*1000); + const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000; + + LOG_INF("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s); + } + + // After each chunk, update the dynamic ngram cache with the context ngram cache: + common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); + ngram_cache_context.clear(); + } + + LOG("\n"); + + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_input - n_input % n_ctx); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); + LOG_INF("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", + t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us)); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git a/llama.cpp/examples/lookup/lookup.cpp b/llama.cpp/examples/lookup/lookup.cpp new file mode 100644 index 0000000..c7552dd --- /dev/null +++ b/llama.cpp/examples/lookup/lookup.cpp @@ -0,0 +1,242 @@ +#include "arg.h" +#include "ggml.h" +#include "common.h" +#include "ngram-cache.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + +#include <cstdint> +#include <cstdio> +#include <fstream> +#include <string> +#include <vector> + +int main(int argc, char ** argv){ + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { + return 1; + } + + common_init(); + + // max. 
number of additional tokens to draft if match is found + const int n_draft = params.speculative.n_max; + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // tokenize the prompt + std::vector<llama_token> inp; + inp = common_tokenize(ctx, params.prompt, true, true); + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + int64_t t_draft_flat_us = 0; + int64_t t_draft_us = 0; + + { + // Fill up context ngram cache with tokens from user input: + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); + + if (!params.speculative.lookup_cache_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); + } catch (std::ifstream::failure const &) { + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); + exit(1); + } + } + + if (!params.speculative.lookup_cache_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); + } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program + } + + t_draft_flat_us += ggml_time_us() - t_start_draft_us; + } + + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + LOG("\n\n"); + + for (auto id : inp) { + LOG("%s", common_token_to_piece(ctx, id).c_str()); + } + + fflush(stderr); + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + + const auto t_enc_end = ggml_time_us(); + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + int n_past = inp.size(); + + bool has_eos = false; + + struct common_sampler * smpl = common_sampler_init(model, params.sampling); + + std::vector<llama_token> draft; + + llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1); + + const auto t_dec_start = ggml_time_us(); + + while (true) { + // print current draft sequence + LOG_DBG("drafted %s\n", string_from(ctx, draft).c_str()); + + int i_dft = 0; + while (true) { + // sample from the target model + llama_token id = common_sampler_sample(smpl, ctx, i_dft); + + common_sampler_accept(smpl, id, true); + + const std::string token_str = common_token_to_piece(ctx, id); + + if (!params.use_color) { + LOG("%s", token_str.c_str()); + } + + if (llama_vocab_is_eog(vocab, id)) { + has_eos = true; + } + + ++n_predict; + + // check if the target token matches the draft + if (i_dft < (int) draft.size() && id == draft[i_dft]) { + LOG_DBG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str()); + ++n_accept; + ++n_past; + ++i_dft; + inp.push_back(id); + { + // Update context ngram cache with the newly accepted token: + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_update(ngram_cache_context, 
LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + t_draft_us += ggml_time_us() - t_start_draft_us; + } + + if (params.use_color) { + // color accepted draft token + LOG("\033[34m%s\033[0m", token_str.c_str()); + fflush(stdout); + } + continue; + } + + if (params.use_color) { + LOG("%s", token_str.c_str()); + } + fflush(stdout); + + + LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + + draft.clear(); + draft.push_back(id); + inp.push_back(id); + { + // Update context ngram cache with the newly accepted token: + const int64_t t_start_draft_us = ggml_time_us(); + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false); + t_draft_us += ggml_time_us() - t_start_draft_us; + } + break; + } + + if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) { + break; + } + + // KV cache management + // clean the cache of draft tokens that weren't accepted + llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1); + + common_batch_clear(batch_tgt); + common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); + + // Draft already contains a single token sampled from the model: + GGML_ASSERT(draft.size() == 1); + GGML_ASSERT(draft[0] == inp.back()); + const int64_t t_start_draft_us = ggml_time_us(); + + common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); + + for (size_t i = 1; i < draft.size(); ++i) { + common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + } + + t_draft_us += ggml_time_us() - t_start_draft_us; + n_drafted += draft.size() - 1; + + llama_decode(ctx, batch_tgt); + ++n_past; + + draft.erase(draft.begin()); + } + + auto t_dec_end = ggml_time_us(); + + // Update dynamic ngram cache with context ngram cache and save it to disk: + common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); + common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic); + + LOG("\n\n"); + + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3); + LOG_INF("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n", + t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us)); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + LOG_INF("\ntarget:\n\n"); + common_perf_print(ctx, smpl); + + common_sampler_free(smpl); + + llama_batch_free(batch_tgt); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git a/llama.cpp/examples/model-conversion/.gitignore b/llama.cpp/examples/model-conversion/.gitignore new file mode 100644 index 0000000..4512275 --- /dev/null +++ b/llama.cpp/examples/model-conversion/.gitignore @@ -0,0 +1,3 @@ +.model_name +data +ppl diff --git a/llama.cpp/examples/model-conversion/Makefile b/llama.cpp/examples/model-conversion/Makefile new file mode 100644 index 0000000..342de63 --- /dev/null +++ b/llama.cpp/examples/model-conversion/Makefile @@ -0,0 +1,232 @@ +MAKEFLAGS += 
--no-print-directory + +define validate_model_path + @if [ -z "$(MODEL_PATH)" ]; then \ + echo "Error: MODEL_PATH must be provided either as:"; \ + echo " 1. Environment variable: export MODEL_PATH=/path/to/model"; \ + echo " 2. Command line argument: make $(1) MODEL_PATH=/path/to/model"; \ + exit 1; \ + fi +endef + +define validate_embedding_model_path + @if [ -z "$(EMBEDDING_MODEL_PATH)" ]; then \ + echo "Error: EMBEDDING_MODEL_PATH must be provided either as:"; \ + echo " 1. Environment variable: export EMBEDDING_MODEL_PATH=/path/to/model"; \ + echo " 2. Command line argument: make $(1) EMBEDDING_MODEL_PATH=/path/to/model"; \ + exit 1; \ + fi +endef + +define quantize_model + @CONVERTED_MODEL="$(1)" QUANTIZED_TYPE="$(QUANTIZED_TYPE)" \ + TOKEN_EMBD_TYPE="$(TOKEN_EMBD_TYPE)" OUTPUT_TYPE="$(OUTPUT_TYPE)" \ + ./scripts/utils/quantize.sh "$(1)" "$(QUANTIZED_TYPE)" "$(TOKEN_EMBD_TYPE)" "$(OUTPUT_TYPE)" + @echo "Export the quantized model path to $(2) variable in your environment" +endef + +DEVICE ?= auto + +### +### Causal Model targets/recipes +### +causal-convert-model-bf16: OUTTYPE=bf16 +causal-convert-model-bf16: causal-convert-model + +causal-convert-model-debug: DEBUG=--debug +causal-convert-model-debug: causal-convert-model + +causal-convert-model: + $(call validate_model_path,causal-convert-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh $(DEBUG) + +causal-convert-mm-model-bf16: OUTTYPE=bf16 +causal-convert-mm-model-bf16: MM_OUTTYPE=f16 +causal-convert-mm-model-bf16: causal-convert-mm-model + +causal-convert-mm-model: + $(call validate_model_path,causal-convert-mm-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh + + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(MM_OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/causal/convert-model.sh --mmproj + +causal-run-original-model: + $(call validate_model_path,causal-run-original-model) + @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/run-org-model.py --device "$(DEVICE)" + +causal-run-converted-model: + @CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/causal/run-converted-model.sh + +causal-verify-logits: causal-run-original-model causal-run-converted-model + @MODEL_PATH="$(MODEL_PATH)" ./scripts/causal/compare-logits.py + @MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH} + +causal-run-original-embeddings: + @./scripts/causal/run-casual-gen-embeddings-org.py + +causal-run-converted-embeddings: + @./scripts/causal/run-converted-model-embeddings-logits.sh + +causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-embeddings + @./scripts/causal/compare-embeddings-logits.sh + +causal-inspect-original-model: + @./scripts/utils/inspect-org-model.py + +causal-inspect-converted-model: + @./scripts/utils/inspect-converted-model.sh + +causal-start-embedding-server: + @./scripts/utils/run-embedding-server.sh ${CONVERTED_MODEL} + +causal-curl-embedding-endpoint: causal-run-original-embeddings + @./scripts/utils/curl-embedding-server.sh | ./scripts/causal/compare-embeddings-logits.sh + +causal-quantize-Q8_0: QUANTIZED_TYPE = Q8_0 +causal-quantize-Q8_0: causal-quantize-model + +causal-quantize-Q4_0: QUANTIZED_TYPE = Q4_0 +causal-quantize-Q4_0: causal-quantize-model + +# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly
set the +# token embedding and output types to Q8_0 instead of the default Q6_K. +causal-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0 +causal-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0 +causal-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0 +causal-quantize-qat-Q4_0: causal-quantize-model + +causal-quantize-model: + $(call quantize_model,$(CONVERTED_MODEL),QUANTIZED_MODEL) + +causal-run-quantized-model: + @QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/causal/run-converted-model.sh ${QUANTIZED_MODEL} + + +### +### Embedding Model targets/recipes +### + +embedding-convert-model-bf16: OUTTYPE=bf16 +embedding-convert-model-bf16: embedding-convert-model + +embedding-convert-model: + $(call validate_embedding_model_path,embedding-convert-model) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/embedding/convert-model.sh + +embedding-convert-model-st: + $(call validate_embedding_model_path,embedding-convert-model-st) + @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ + ./scripts/embedding/convert-model.sh -st + +embedding-run-original-model: + $(call validate_embedding_model_path,embedding-run-original-model) + @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" \ + USE_SENTENCE_TRANSFORMERS="$(USE_SENTENCE_TRANSFORMERS)" \ + ./scripts/embedding/run-original-model.py \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(USE_SENTENCE_TRANSFORMERS),--use-sentence-transformers) + +embedding-run-original-model-st: USE_SENTENCE_TRANSFORMERS=1 +embedding-run-original-model-st: embedding-run-original-model + +embedding-run-converted-model: + @./scripts/embedding/run-converted-model.sh $(CONVERTED_EMBEDDING_MODEL) \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") \ + $(if $(EMBD_NORMALIZE),--embd-normalize "$(EMBD_NORMALIZE)") + +embedding-verify-logits: embedding-run-original-model embedding-run-converted-model + @./scripts/embedding/compare-embeddings-logits.sh \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + +embedding-verify-logits-st: embedding-run-original-model-st embedding-run-converted-model + @./scripts/embedding/compare-embeddings-logits.sh \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + +embedding-inspect-original-model: + $(call validate_embedding_model_path,embedding-inspect-original-model) + @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} + +embedding-inspect-converted-model: + @CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL} + +embedding-start-embedding-server: + @./scripts/utils/run-embedding-server.sh ${CONVERTED_EMBEDDING_MODEL} + +embedding-curl-embedding-endpoint: + @./scripts/utils/curl-embedding-server.sh | ./scripts/embedding/compare-embeddings-logits.sh + +embedding-quantize-Q8_0: QUANTIZED_TYPE = Q8_0 +embedding-quantize-Q8_0: embedding-quantize-model + +embedding-quantize-Q4_0: QUANTIZED_TYPE = Q4_0 +embedding-quantize-Q4_0: embedding-quantize-model + +# For Quantization Aware Trained (QAT) models in Q4_0 we explicitly set the +# token embedding and output types to Q8_0 instead of the default Q6_K. 
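+# Example usage (assuming CONVERTED_EMBEDDING_MODEL is set to the converted GGUF file): +# make embedding-quantize-qat-Q4_0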
+embedding-quantize-qat-Q4_0: QUANTIZED_TYPE = Q4_0 +embedding-quantize-qat-Q4_0: TOKEN_EMBD_TYPE = Q8_0 +embedding-quantize-qat-Q4_0: OUTPUT_TYPE = Q8_0 +embedding-quantize-qat-Q4_0: embedding-quantize-model + +embedding-quantize-model: + $(call quantize_model,$(CONVERTED_EMBEDDING_MODEL),QUANTIZED_EMBEDDING_MODEL) + +embedding-run-quantized-model: + @./scripts/embedding/run-converted-model.sh $(QUANTIZED_EMBEDDING_MODEL) \ + $(if $(PROMPTS_FILE),--prompts-file "$(PROMPTS_FILE)") + +### +### Perplexity targets/recipes +### +perplexity-data-gen: + CONVERTED_MODEL="$(CONVERTED_MODEL)" ./scripts/utils/perplexity-gen.sh + +perplexity-run-full: + QUANTIZED_MODEL="$(QUANTIZED_MODEL)" LOGITS_FILE="$(LOGITS_FILE)" \ + ./scripts/utils/perplexity-run.sh + +perplexity-run: + QUANTIZED_MODEL="$(QUANTIZED_MODEL)" ./scripts/utils/perplexity-run-simple.sh + +### +### HuggingFace targets/recipes +### + +hf-create-model: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" + +hf-create-model-dry-run: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -d + +hf-create-model-embedding: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e + +hf-create-model-embedding-dry-run: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -e -d + +hf-create-model-private: + @./scripts/utils/hf-create-model.py -m "${MODEL_NAME}" -ns "${NAMESPACE}" -b "${ORIGINAL_BASE_MODEL}" -p + +hf-upload-gguf-to-model: + @./scripts/utils/hf-upload-gguf-model.py -m "${MODEL_PATH}" -r "${REPO_ID}" -o "${NAME_IN_REPO}" + +hf-create-collection: + @./scripts/utils/hf-create-collection.py -n "${NAME}" -d "${DESCRIPTION}" -ns "${NAMESPACE}" + +hf-add-model-to-collection: + @./scripts/utils/hf-add-model-to-collection.py -c "${COLLECTION}" -m "${MODEL}" + + +.PHONY: clean +clean: + @${RM} -rf data .converted_embedding_model.txt .converted_model.txt .embedding_model_name.txt .model_name.txt + diff --git a/llama.cpp/examples/model-conversion/README.md b/llama.cpp/examples/model-conversion/README.md new file mode 100644 index 0000000..637870a --- /dev/null +++ b/llama.cpp/examples/model-conversion/README.md @@ -0,0 +1,408 @@ +# Model Conversion Example +This directory contains scripts and code to help in the process of converting +HuggingFace PyTorch models to GGUF format. + +The motivation for this is that conversion is often an iterative process, where +the original model is inspected, converted, updates are made to llama.cpp, the +model is converted again, and so on. Once the model has been converted it +needs to be verified against the original model, then optionally quantized, +and in some cases the perplexity of the quantized model checked. And finally the +model/models need to be uploaded to ggml-org on Hugging Face. This tool/example +tries to help with this process. + +> 📝 **Note:** When adding a new model from an existing family, verify that the +> previous version passes logits verification first. Existing models can have +> subtle numerical differences that don't affect generation quality but cause +> logits mismatches. Identifying these upfront, whether they exist in llama.cpp, +> the conversion script, or in an upstream implementation, can save significant +> debugging time.
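+> +> For example, before adding a hypothetical `some_model_v2`, it can be worth running +> the logits check (described below) on the already-supported previous version first: +> +> ```console +> (venv) $ make causal-verify-logits MODEL_PATH=~/work/ai/models/some_model_v1 +> ```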
+
+### Overview
+The idea is that the makefile targets and scripts here can be used in the
+development/conversion process, assisting with things like:
+
+* inspect/run the original model to figure out how it works
+* convert the original model to GGUF format
+* inspect/run the converted model
+* verify the logits produced by the original model and the converted model
+* quantize the model to GGUF format
+* run perplexity evaluation to verify that the quantized model is performing
+  as expected
+* upload the model to HuggingFace to make it available for others
+
+## Setup
+Create a virtual Python environment:
+```console
+$ python3.11 -m venv venv
+$ source venv/bin/activate
+(venv) $ pip install -r requirements.txt
+```
+
+## Causal Language Model Conversion
+This section describes the steps to convert a causal language model to GGUF and
+to verify that the conversion was successful.
+
+### Download the original model
+First, clone the original model to some local directory:
+```console
+$ mkdir models && cd models
+$ git clone https://huggingface.co/user/model_name
+$ cd model_name
+$ git lfs install
+$ git lfs pull
+```
+
+### Set the MODEL_PATH
+The path to the downloaded model can be provided in two ways:
+
+**Option 1: Environment variable (recommended for iterative development)**
+```console
+export MODEL_PATH=~/work/ai/models/some_model
+```
+
+**Option 2: Command line argument (for one-off tasks)**
+```console
+make causal-convert-model MODEL_PATH=~/work/ai/models/some_model
+```
+
+Command line arguments take precedence over environment variables when both are provided.
+
+In cases where the transformer implementation for the model has not been released
+yet, it is possible to set the environment variable `UNRELEASED_MODEL_NAME`, which
+will then cause the transformer implementation to be loaded explicitly instead of
+using `AutoModelForCausalLM`:
+```
+export UNRELEASED_MODEL_NAME=SomeNewModel
+```
+
+### Inspecting the original tensors
+```console
+# Using environment variable
+(venv) $ make causal-inspect-original-model
+
+# Or using command line argument
+(venv) $ make causal-inspect-original-model MODEL_PATH=~/work/ai/models/some_model
+```
+
+### Running the original model
+This is mainly to verify that the original model works, and to compare its output
+with that of the converted model.
+```console
+# Using environment variable
+(venv) $ make causal-run-original-model
+
+# Or using command line argument
+(venv) $ make causal-run-original-model MODEL_PATH=~/work/ai/models/some_model
+```
+This command will save two files to the `data` directory: one is a binary file
+containing logits which will be used for comparison with the converted model
+later, and the other is a text file which allows for manual visual inspection.
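+
+For reference, the logits binary file is simply a flat array of raw `float32`
+values, one per vocabulary entry for the last token, so it can be inspected
+directly. A minimal sketch, assuming a hypothetical model named `MyModel`
+(adjust the file name to match your model):
+```python
+import numpy as np
+
+# Raw float32 logits as written by the run target above.
+logits = np.fromfile("data/pytorch-MyModel.bin", dtype=np.float32)
+print(f"vocab size: {logits.shape[0]}")
+
+# Indices of the five highest logits, i.e. the most likely next tokens.
+top5 = np.argsort(logits)[-5:][::-1]
+print("top-5 token ids:", top5, "values:", logits[top5])
+```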
+
+### Model conversion
+After updates have been made to [gguf-py](../../gguf-py) to add support for the
+new model, the model can be converted to GGUF format using the following command:
+```console
+# Using environment variable
+(venv) $ make causal-convert-model
+
+# Or using command line argument
+(venv) $ make causal-convert-model MODEL_PATH=~/work/ai/models/some_model
+```
+
+### Inspecting the converted model
+The converted model can be inspected using the following command:
+```console
+(venv) $ make causal-inspect-converted-model
+```
+
+### Running the converted model
+```console
+(venv) $ make causal-run-converted-model
+```
+
+### Model logits verification
+The following target will run the original model and the converted model and
+compare the logits:
+```console
+(venv) $ make causal-verify-logits
+```
+
+### Quantizing the model
+The causal model can be quantized to GGUF format using the following command:
+```console
+(venv) $ make causal-quantize-Q8_0
+Quantized model saved to: /path/to/quantized/model-Q8_0.gguf
+Export the quantized model path to QUANTIZED_MODEL variable in your environment
+```
+This will show the path to the quantized model in the terminal, which can then
+be used to set the `QUANTIZED_MODEL` environment variable:
+```console
+export QUANTIZED_MODEL=/path/to/quantized/model-Q8_0.gguf
+```
+Then the quantized model can be run using the following command:
+```console
+(venv) $ make causal-run-quantized-model
+```
+
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make causal-quantize-qat-Q4_0
+```
+
+
+## Embedding Language Model Conversion
+
+### Download the original model
+```console
+$ mkdir models && cd models
+$ git clone https://huggingface.co/user/model_name
+$ cd model_name
+$ git lfs install
+$ git lfs pull
+```
+
+The path to the embedding model can be provided in two ways:
+
+**Option 1: Environment variable (recommended for iterative development)**
+```console
+export EMBEDDING_MODEL_PATH=~/path/to/embedding_model
+```
+
+**Option 2: Command line argument (for one-off tasks)**
+```console
+make embedding-convert-model EMBEDDING_MODEL_PATH=~/path/to/embedding_model
+```
+
+Command line arguments take precedence over environment variables when both are provided.
+
+### Running the original model
+This is mainly to verify that the original model works and to compare its output
+with that of the converted model.
+```console
+# Using environment variable
+(venv) $ make embedding-run-original-model
+
+# Or using command line argument
+(venv) $ make embedding-run-original-model EMBEDDING_MODEL_PATH=~/path/to/embedding_model
+```
+This command will save two files to the `data` directory: one is a binary
+file containing logits which will be used for comparison with the converted
+model, and the other is a text file which allows for manual visual inspection.
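+
+As with the causal model, the saved file is a flat array of raw `float32`
+values, and the verification described further down is based on cosine
+similarity between the two embedding files. A minimal sketch of what that
+comparison boils down to, with hypothetical file names following the pattern
+used by the scripts:
+```python
+import numpy as np
+
+# File names assumed: the scripts write data/pytorch-<name>-embeddings.bin
+# and data/llamacpp-<name>-embeddings.bin.
+a = np.fromfile("data/pytorch-my-model-embeddings.bin", dtype=np.float32)
+b = np.fromfile("data/llamacpp-my-model-embeddings.bin", dtype=np.float32)
+
+# Cosine similarity; a value close to 1.0 means the two models agree.
+cos = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+print(f"cosine similarity: {cos:.6f}")
+```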
+
+#### Using SentenceTransformer with numbered layers
+For models that have numbered SentenceTransformer layers (01_Pooling, 02_Dense,
+03_Dense, 04_Normalize), these will be applied automatically when running the
+converted model, but currently there is a separate target to run the original
+version:
+
+```console
+# Run original model with SentenceTransformer (applies all numbered layers)
+(venv) $ make embedding-run-original-model-st
+```
+
+This will use the SentenceTransformer library to load and run the model, which
+automatically applies all the numbered layers in the correct order. This is
+particularly useful when comparing with models that should include these
+additional transformation layers beyond just the base model output.
+
+The type of normalization can be specified for the converted model, but this is
+not strictly necessary as the verification uses cosine similarity, which is not
+affected by the magnitude of the output vectors. Still, the normalization type
+can be specified as an argument to the target, which might be useful for manual
+inspection:
+```console
+(venv) $ make embedding-verify-logits-st EMBD_NORMALIZE=1
+```
+The original model will apply the normalization according to the normalization
+layer specified in the modules.json configuration file.
+
+### Model conversion
+After updates have been made to [gguf-py](../../gguf-py) to add support for the
+new model, the model can be converted to GGUF format using the following command:
+```console
+(venv) $ make embedding-convert-model
+```
+
+### Run the converted model
+```console
+(venv) $ make embedding-run-converted-model
+```
+
+### Model logits verification
+The following target will run the original model and the converted model (which
+was done manually in the previous steps) and compare the logits:
+```console
+(venv) $ make embedding-verify-logits
+```
+
+For models with SentenceTransformer layers, use the `-st` verification target:
+```console
+(venv) $ make embedding-verify-logits-st
+```
+This convenience target automatically runs both the original model with SentenceTransformer
+and the converted model with pooling enabled, then compares the results.
+
+### llama-server verification
+To verify that the converted model works with llama-server, the following
+command can be used:
+```console
+(venv) $ make embedding-start-embedding-server
+```
+Then open another terminal and set the `EMBEDDING_MODEL_PATH` environment
+variable, as this will not be inherited by the new terminal:
+```console
+(venv) $ make embedding-curl-embedding-endpoint
+```
+This will call the `embedding` endpoint and the output will be piped into
+the same verification script as used by the target `embedding-verify-logits`.
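+
+The endpoint can also be called directly, which can be useful when debugging.
+A minimal sketch using only the Python standard library; the port is the
+llama-server default and the response layout is assumed to match what the
+compare script consumes:
+```python
+import json
+import urllib.request
+
+# POST a prompt to a running llama-server started with --embeddings.
+req = urllib.request.Request(
+    "http://localhost:8080/embedding",
+    data=json.dumps({"input": "Hello world today"}).encode("utf-8"),
+    headers={"Content-Type": "application/json"},
+)
+with urllib.request.urlopen(req) as resp:
+    result = json.load(resp)
+
+# The compare script expects a list of items, each with an 'embedding' field.
+print(f"items: {len(result)}")
+print(f"embedding rows in first item: {len(result[0]['embedding'])}")
+```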
+
+The causal model can also be used to produce embeddings, and this can be verified
+using the following commands:
+```console
+(venv) $ make causal-start-embedding-server
+```
+Then open another terminal and set the `MODEL_PATH` environment
+variable, as this will not be inherited by the new terminal:
+```console
+(venv) $ make causal-curl-embedding-endpoint
+```
+
+### Quantizing the model
+The embedding model can be quantized to GGUF format using the following command:
+```console
+(venv) $ make embedding-quantize-Q8_0
+Quantized model saved to: /path/to/quantized/model-Q8_0.gguf
+Export the quantized model path to QUANTIZED_EMBEDDING_MODEL variable in your environment
+```
+This will show the path to the quantized model in the terminal, which can then
+be used to set the `QUANTIZED_EMBEDDING_MODEL` environment variable:
+```console
+export QUANTIZED_EMBEDDING_MODEL=/path/to/quantized/model-Q8_0.gguf
+```
+Then the quantized model can be run using the following command:
+```console
+(venv) $ make embedding-run-quantized-model
+```
+
+### Quantizing QAT (Quantization Aware Training) models
+When quantizing to `Q4_0`, the default data type for the token embedding weights
+will be `Q6_K`. For models that are going to be uploaded to ggml-org it is
+recommended to use `Q8_0` instead for the embeddings and output tensors.
+The reason is that although `Q6_K` is smaller in size, it requires more compute
+to unpack, which can hurt performance during output generation when the entire
+embedding matrix must be dequantized to compute vocabulary logits. `Q8_0`
+provides practically full quality with better computational efficiency.
+```console
+(venv) $ make embedding-quantize-qat-Q4_0
+```
+
+## Perplexity Evaluation
+
+### Simple perplexity evaluation
+This allows running the perplexity evaluation without having to generate a
+token/logits file:
+```console
+(venv) $ make perplexity-run QUANTIZED_MODEL=~/path/to/quantized/model.gguf
+```
+This will use the wikitext dataset to run the perplexity evaluation and
+output the perplexity score to the terminal. This value can then be compared
+with the perplexity score of the unquantized model.
+
+### Full perplexity evaluation
+First, use the converted, non-quantized model to generate the perplexity evaluation
+dataset using the following command:
+```console
+$ make perplexity-data-gen CONVERTED_MODEL=~/path/to/converted/model.gguf
+```
+This will generate a file in the `data` directory named after the model and with
+a `.kld` suffix which contains the tokens and the logits for the wikitext dataset.
+
+After the dataset has been generated, the perplexity evaluation can be run using
+the quantized model:
+```console
+$ make perplexity-run-full QUANTIZED_MODEL=~/path/to/quantized/model-Qxx.gguf LOGITS_FILE=data/model.gguf.ppl
+```
+
+> 📝 **Note:** The `LOGITS_FILE` generated by the previous command can be very
+> large, so make sure you have enough disk space available.
+
+## HuggingFace utilities
+The following targets are useful for creating collections and model repositories
+on Hugging Face in the ggml-org. These can be used when preparing a release
+to script the process for new model releases.
+
+For the following targets a `HF_TOKEN` environment variable is required.
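+
+A minimal sketch of how such a token can be used for authentication, assuming
+the `huggingface_hub` package from `requirements.txt` (the actual scripts may
+differ in detail):
+```python
+import os
+from huggingface_hub import HfApi
+
+# Authenticate with the token from the environment, as the targets below expect.
+api = HfApi(token=os.environ["HF_TOKEN"])
+user = api.whoami()
+print(f"Authenticated as: {user['name']}")
+```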
+
+> 📝 **Note:** Don't forget to log out from Hugging Face after running these
+> commands, otherwise you might have issues pulling/cloning repositories as
+> the token will still be in use:
+> `$ huggingface-cli logout`
+> `$ unset HF_TOKEN`
+
+### Create a new Hugging Face Model (model repository)
+This will create a new model repository on Hugging Face with the specified
+model name.
+```console
+(venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
+Repository ID: danbev/TestModel-GGUF
+Repository created: https://huggingface.co/danbev/TestModel-GGUF
+```
+Note that we append a `-GGUF` suffix to the model name to ensure a consistent
+naming convention for GGUF models.
+
+An embedding model can be created using the following command:
+```console
+(venv) $ make hf-create-model-embedding MODEL_NAME='TestEmbeddingModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model"
+```
+The only difference is that the model card for an embedding model will differ
+with regard to the llama-server command and how to access/call the embedding
+endpoint.
+
+### Upload a GGUF model to model repository
+The following target uploads a model to an existing Hugging Face model repository.
+```console
+(venv) $ make hf-upload-gguf-to-model MODEL_PATH=dummy-model1.gguf REPO_ID=danbev/TestModel-GGUF
+📤 Uploading dummy-model1.gguf to danbev/TestModel-GGUF/dummy-model1.gguf
+✅ Upload successful!
+🔗 File available at: https://huggingface.co/danbev/TestModel-GGUF/blob/main/dummy-model1.gguf
+```
+This command can also be used to update an existing model file in a repository.
+
+### Create a new Collection
+```console
+(venv) $ make hf-create-collection NAME=TestCollection DESCRIPTION="Collection for testing scripts" NAMESPACE=danbev
+🚀 Creating Hugging Face Collection
+Title: TestCollection
+Description: Collection for testing scripts
+Namespace: danbev
+Private: False
+✅ Authenticated as: danbev
+📚 Creating collection: 'TestCollection'...
+✅ Collection created successfully!
+📋 Collection slug: danbev/testcollection-68930fcf73eb3fc200b9956d
+🔗 Collection URL: https://huggingface.co/collections/danbev/testcollection-68930fcf73eb3fc200b9956d
+
+🎉 Collection created successfully!
+Use this slug to add models: danbev/testcollection-68930fcf73eb3fc200b9956d
+```
+
+### Add model to a Collection
+```console
+(venv) $ make hf-add-model-to-collection COLLECTION=danbev/testcollection-68930fcf73eb3fc200b9956d MODEL=danbev/TestModel-GGUF
+✅ Authenticated as: danbev
+🔍 Checking if model exists: danbev/TestModel-GGUF
+✅ Model found: danbev/TestModel-GGUF
+📚 Adding model to collection...
+✅ Model added to collection successfully!
+🔗 Collection URL: https://huggingface.co/collections/danbev/testcollection-68930fcf73eb3fc200b9956d
+
+🎉 Model added successfully!
+
+```
diff --git a/llama.cpp/examples/model-conversion/requirements.txt b/llama.cpp/examples/model-conversion/requirements.txt
new file mode 100644
index 0000000..229b2ec
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/requirements.txt
@@ -0,0 +1,7 @@
+--extra-index-url https://download.pytorch.org/whl/cpu
+torch
+torchvision
+transformers
+huggingface-hub
+accelerate
+sentence-transformers
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh b/llama.cpp/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
new file mode 100755
index 0000000..2ae4dc7
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/causal/compare-embeddings-logits.sh
@@ -0,0 +1,46 @@
+#!/usr/bin/env bash
+
+set -e
+
+MODEL_PATH="${1:-"$MODEL_PATH"}"
+MODEL_NAME="${2:-$(basename "$MODEL_PATH")}"
+
+CONVERTED_MODEL_PATH="${1:-"$CONVERTED_MODEL"}"
+CONVERTED_MODEL_NAME="${2:-$(basename "$CONVERTED_MODEL_PATH" ".gguf")}"
+
+if [ -t 0 ]; then
+    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
+else
+    # Process piped JSON data and convert to binary (matching logits.cpp format)
+    TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.bin)
+    python3 -c "
+import json
+import sys
+import struct
+
+data = json.load(sys.stdin)
+
+# Flatten all embeddings completely
+flattened = []
+for item in data:
+    embedding = item['embedding']
+    for token_embedding in embedding:
+        flattened.extend(token_embedding)
+
+print(f'Total embedding values: {len(flattened)}', file=sys.stderr)
+
+# Write as binary floats - matches logits.cpp fwrite format
+with open('$TEMP_FILE', 'wb') as f:
+    for value in flattened:
+        f.write(struct.pack('f', value))
+"
+    CPP_EMBEDDINGS="$TEMP_FILE"
+    trap "rm -f $TEMP_FILE" EXIT
+fi
+
+python scripts/utils/semantic_check.py --model-path "$MODEL_PATH" \
+    --python-embeddings data/pytorch-${MODEL_NAME}-embeddings.bin \
+    --cpp-embeddings "$CPP_EMBEDDINGS" \
+    --prompt "Hello world today" \
+    --causal
+
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/compare-logits.py b/llama.cpp/examples/model-conversion/scripts/causal/compare-logits.py
new file mode 100755
index 0000000..83bd14c
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/causal/compare-logits.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python3
+
+import sys
+import numpy as np
+from pathlib import Path
+import os
+
+# Add utils directory to path for direct script execution
+sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
+from common import get_model_name_from_env_path, compare_tokens, exit_with_warning  # type: ignore[import-not-found]
+
+def quick_logits_check(pytorch_file, llamacpp_file):
+    """Lightweight sanity check before NMSE"""
+
+    try:
+        pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
+        llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
+    except Exception as e:
+        print(f"❌ NOK: Failed to load files - {e}")
+        return False
+
+    # Check shapes match
+    if pytorch_logits.shape != llamacpp_logits.shape:
+        print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
+        return False
+
+    # Calculate key metrics
+    diff = pytorch_logits - llamacpp_logits
+    abs_diff = np.abs(diff)
+    max_diff = np.max(abs_diff)
+
+    # Get top 10 predictions from both models
+    pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
+    llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]
+    print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
+    print(f"Top 10 llama.cpp logits: 
{llamacpp_logits[llamacpp_top10]}")
+    print(f"Max absolute difference: {max_diff:.4f}")
+
+    # Fail fast if the two models don't even agree on the top 10 tokens.
+    if not np.array_equal(pytorch_top10, llamacpp_top10):
+        print("❌ NOK: Top 10 predictions don't match")
+        return False
+
+    return True
+
+def main():
+    model_path = os.environ.get('MODEL_PATH')
+    model_name = get_model_name_from_env_path('MODEL_PATH')
+    data_dir = Path("data")
+    pytorch_file = data_dir / f"pytorch-{model_name}.bin"
+
+    llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL')
+    print(f"Using converted model: {llamacpp_model_name}")
+    llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin"
+
+    if not pytorch_file.exists():
+        print(f"Error: PyTorch logits file not found: {pytorch_file}")
+        print("Please run scripts/run-org-model.sh first to generate this file.")
+        sys.exit(1)
+
+    if not llamacpp_file.exists():
+        print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
+        print("Please run scripts/run-converted-model.sh first to generate this file.")
+        sys.exit(1)
+
+    print("Checked all required files were found. Proceeding...\n")
+
+    # Verify tokens as they are a prerequisite for logits comparison.
+    print("🔍 Token Comparison Check")
+    print("=" * 40)
+    if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
+        exit_with_warning("\n❌ Token mismatch detected", model_path)
+    print()
+
+    print("🔍 GGML Model Validation for model ", model_name)
+    print("=" * 40)
+    print(f"PyTorch logits : {pytorch_file}")
+    print(f"llama.cpp logits: {llamacpp_file}")
+    print()
+
+    success = quick_logits_check(pytorch_file, llamacpp_file)
+
+    # Exit with appropriate code
+    if success:
+        print("✅ OK: Lightweight model check successful!")
+        print(" Ok to proceed with NMSE check...")
+        sys.exit(0)
+    else:
+        exit_with_warning("❌ NOK: Top 10 predictions don't match - generation will differ", model_path)
+
+if __name__ == "__main__":
+    main()
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/convert-model.sh b/llama.cpp/examples/model-conversion/scripts/causal/convert-model.sh
new file mode 100755
index 0000000..a5865f6
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/causal/convert-model.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+
+set -e
+
+# Parse command line arguments
+MMPROJ=""
+DEBUG=""
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --mmproj)
+            MMPROJ="--mmproj"
+            shift
+            ;;
+        --debug)
+            DEBUG="1"
+            shift
+            ;;
+        *)
+            shift
+            ;;
+    esac
+done
+
+MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}"
+OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
+TYPE="${OUTTYPE:-f16}"
+METADATA_OVERRIDE="${METADATA_OVERRIDE:-}"
+CONVERTED_MODEL="${OUTPUT_DIR}/${MODEL_NAME}.gguf"
+
+echo "Model path: ${MODEL_PATH}"
+echo "Model name: ${MODEL_NAME}"
+echo "Data type: ${TYPE}"
+echo "Converted model path: ${CONVERTED_MODEL}"
+echo "Metadata override: ${METADATA_OVERRIDE}"
+
+if [[ -n "$DEBUG" ]]; then
+    CMD_ARGS=("python" "-m" "pdb")
+else
+    CMD_ARGS=("python")
+fi
+CMD_ARGS+=("../../convert_hf_to_gguf.py" "--verbose")
+CMD_ARGS+=("${MODEL_PATH}")
+CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
+CMD_ARGS+=("--outtype" "${TYPE}")
+[[ -n "$METADATA_OVERRIDE" ]] && CMD_ARGS+=("--metadata" "${METADATA_OVERRIDE}")
+[[ -n "$MMPROJ" ]] && CMD_ARGS+=("${MMPROJ}")
+
+"${CMD_ARGS[@]}"
+
+echo ""
+echo "The environment variable CONVERTED_MODEL can be set to this path using:"
+echo "export CONVERTED_MODEL=$(realpath ${CONVERTED_MODEL})"
+if [[ -n "$MMPROJ" ]]; then
+    mmproj_file="${OUTPUT_DIR}/mmproj-$(basename "${CONVERTED_MODEL}")"
+    echo "The mmproj model was created in $(realpath "$mmproj_file")"
+fi
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/modelcard.template 
b/llama.cpp/examples/model-conversion/scripts/causal/modelcard.template new file mode 100644 index 0000000..a045950 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/causal/modelcard.template @@ -0,0 +1,13 @@ +--- +base_model: +- {base_model} +--- +# {model_name} GGUF + +Recommended way to run this model: + +```sh +llama-server -hf {namespace}/{model_name}-GGUF +``` + +Then, access http://localhost:8080 diff --git a/llama.cpp/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py b/llama.cpp/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py new file mode 100755 index 0000000..4ab778f --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 + +import argparse +import os +import importlib +import torch +import numpy as np + +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM +from pathlib import Path + +unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') + +parser = argparse.ArgumentParser(description='Process model with specified path') +parser.add_argument('--model-path', '-m', help='Path to the model') +args = parser.parse_args() + +model_path = os.environ.get('MODEL_PATH', args.model_path) +if model_path is None: + parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + +config = AutoConfig.from_pretrained(model_path) + +print("Model type: ", config.model_type) +print("Vocab size: ", config.vocab_size) +print("Hidden size: ", config.hidden_size) +print("Number of layers: ", config.num_hidden_layers) +print("BOS token id: ", config.bos_token_id) +print("EOS token id: ", config.eos_token_id) + +print("Loading model and tokenizer using AutoTokenizer:", model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) + +if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + class_name = f"{unreleased_model_name}ForCausalLM" + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(model_path) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + print("Falling back to AutoModelForCausalLM") + model = AutoModelForCausalLM.from_pretrained(model_path) +else: + model = AutoModelForCausalLM.from_pretrained(model_path) +print(f"Model class: {type(model)}") +#print(f"Model file: {type(model).__module__}") + +model_name = os.path.basename(model_path) +print(f"Model name: {model_name}") + +prompt = "Hello world today" +input_ids = tokenizer(prompt, return_tensors="pt").input_ids +print(f"Input tokens: {input_ids}") +print(f"Input text: {repr(prompt)}") +print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") + +with torch.no_grad(): + outputs = model(input_ids, output_hidden_states=True) + + # Extract hidden states from the last layer + # outputs.hidden_states is a tuple of (num_layers + 1) tensors + # Index -1 gets the last layer, shape: [batch_size, seq_len, hidden_size] + last_hidden_states = outputs.hidden_states[-1] + + # Get embeddings for all tokens + token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension + + print(f"Hidden states shape: {last_hidden_states.shape}") + print(f"Token embeddings shape: 
{token_embeddings.shape}")
+    print(f"Hidden dimension: {token_embeddings.shape[-1]}")
+    print(f"Number of tokens: {token_embeddings.shape[0]}")
+
+    # Save raw token embeddings
+    data_dir = Path("data")
+    data_dir.mkdir(exist_ok=True)
+    bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
+    txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
+
+    # Save all token embeddings as binary
+    print(token_embeddings)
+    token_embeddings.astype(np.float32).tofile(bin_filename)
+
+    # Save as text for inspection
+    with open(txt_filename, "w") as f:
+        for i, embedding in enumerate(token_embeddings):
+            for j, val in enumerate(embedding):
+                f.write(f"{i} {j} {val:.6f}\n")
+
+    # Print embeddings per token in the requested format
+    print("\nToken embeddings:")
+    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
+    for i, embedding in enumerate(token_embeddings):
+        # Format: show first few values, ..., then last few values
+        if len(embedding) > 10:
+            # Show first 3 and last 3 values with ... in between
+            first_vals = " ".join(f"{val:8.6f}" for val in embedding[:3])
+            last_vals = " ".join(f"{val:8.6f}" for val in embedding[-3:])
+            print(f"embedding {i}: {first_vals} ... {last_vals}")
+        else:
+            # If embedding is short, show all values
+            vals = " ".join(f"{val:8.6f}" for val in embedding)
+            print(f"embedding {i}: {vals}")
+
+    # Also show token info for reference
+    print("\nToken reference:")
+    for i, token in enumerate(tokens):
+        print(f"  Token {i}: {repr(token)}")
+
+    print(f"Saved bin logits to: {bin_filename}")
+    print(f"Saved txt logits to: {txt_filename}")
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
new file mode 100755
index 0000000..1b5ff86
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+
+set -e
+
+# First try command line argument, then environment variable, then file
+CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+BUILD_DIR="${2:-"$BUILD_DIR"}"
+
+# Final check if we have a model path
+if [ -z "$CONVERTED_MODEL" ]; then
+    echo "Error: Model path must be provided either as:" >&2
+    echo " 1. Command line argument" >&2
+    echo " 2. CONVERTED_MODEL environment variable" >&2
+    exit 1
+fi
+
+if [ -z "$BUILD_DIR" ]; then
+    BUILD_DIR="../../build"
+fi
+
+cmake --build ${BUILD_DIR} --target llama-debug -j8
+
+${BUILD_DIR}/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits
diff --git a/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model.sh b/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model.sh
new file mode 100755
index 0000000..b684804
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/causal/run-converted-model.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+set -e
+
+# First try command line argument, then environment variable, then file
+CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
+MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
+BUILD_DIR="${3:-"$BUILD_DIR"}"
+
+if [ -z "$MODEL_TESTING_PROMPT" ]; then
+    MODEL_TESTING_PROMPT="Hello, my name is"
+fi
+
+if [ -z "$BUILD_DIR" ]; then
+    BUILD_DIR="../../build"
+fi
+
+# Final check if we have a model path
+if [ -z "$CONVERTED_MODEL" ]; then
+    echo "Error: Model path must be provided either as:" >&2
+    echo " 1. Command line argument" >&2
+    echo " 2. 
CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +echo $CONVERTED_MODEL +echo $MODEL_TESTING_PROMPT + +cmake --build ${BUILD_DIR} --target llama-debug -j8 + +${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/llama.cpp/examples/model-conversion/scripts/causal/run-org-model.py b/llama.cpp/examples/model-conversion/scripts/causal/run-org-model.py new file mode 100755 index 0000000..215f1a9 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/causal/run-org-model.py @@ -0,0 +1,168 @@ +#!/usr/bin/env python3 + +import argparse +import os +import sys +import importlib +import torch +import numpy as np + +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig + +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from utils.common import debug_hook, save_output_data + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Process model with specified path") + parser.add_argument("--model-path", "-m", help="Path to the model") + parser.add_argument("--prompt-file", "-f", help="Optional prompt file", required=False) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug output") + parser.add_argument("--device", "-d", help="Device to use (cpu, cuda, mps, auto)", default="auto") + return parser.parse_args() + +def load_model_and_tokenizer(model_path, device="auto"): + print("Loading model and tokenizer using AutoTokenizer:", model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + multimodal = False + full_config = config + + # Determine device_map based on device argument + if device == "cpu": + device_map = {"": "cpu"} + print("Forcing CPU usage") + elif device == "auto": + device_map = "auto" + else: + device_map = {"": device} + + print("Model type: ", config.model_type) + if "vocab_size" not in config and "text_config" in config: + config = config.text_config + multimodal = True + + print("Vocab size: ", config.vocab_size) + print("Hidden size: ", config.hidden_size) + print("Number of layers: ", config.num_hidden_layers) + print("BOS token id: ", config.bos_token_id) + print("EOS token id: ", config.eos_token_id) + + unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = ( + f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + ) + class_name = f"{unreleased_model_name}ForCausalLM" + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=config + ) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + if multimodal: + model = AutoModelForImageTextToText.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=full_config + ) + else: + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map=device_map, + offload_folder="offload", + trust_remote_code=True, + config=config + ) + + print(f"Model class: {model.__class__.__name__}") + + return model, 
tokenizer, config + +def enable_torch_debugging(model): + for name, module in model.named_modules(): + if len(list(module.children())) == 0: # only leaf modules + module.register_forward_hook(debug_hook(name)) + +def get_prompt(args): + if args.prompt_file: + with open(args.prompt_file, encoding='utf-8') as f: + return f.read() + elif os.getenv("MODEL_TESTING_PROMPT"): + return os.getenv("MODEL_TESTING_PROMPT") + else: + return "Hello, my name is" + +def main(): + args = parse_arguments() + model_path = os.environ.get("MODEL_PATH", args.model_path) + if model_path is None: + print("Error: Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + sys.exit(1) + + + model, tokenizer, config = load_model_and_tokenizer(model_path, args.device) + + if args.verbose: + enable_torch_debugging(model) + + model_name = os.path.basename(model_path) + + # Iterate over the model parameters (the tensors) and get the first one + # and use it to get the device the model is on. + device = next(model.parameters()).device + prompt = get_prompt(args) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + token_ids = input_ids[0].cpu().tolist() + + print(f"Input tokens: {input_ids}") + print(f"Input text: {repr(prompt)}") + print(f"Tokenized: {tokenizer.convert_ids_to_tokens(input_ids[0])}") + + batch_size = 512 + + with torch.no_grad(): + past = None + outputs = None + for i in range(0, input_ids.size(1), batch_size): + print(f"Processing chunk with tokens {i} to {i + batch_size}") + chunk = input_ids[:, i:i + batch_size] + outputs = model(chunk.to(model.device), past_key_values=past, use_cache=True) + past = outputs.past_key_values + + logits = outputs.logits # type: ignore + + # Extract logits for the last token (next token prediction) + last_logits = logits[0, -1, :].float().cpu().numpy() + + print(f"Logits shape: {logits.shape}") + print(f"Last token logits shape: {last_logits.shape}") + print(f"Vocab size: {len(last_logits)}") + + # Print some sample logits for quick verification + print(f"First 10 logits: {last_logits[:10]}") + print(f"Last 10 logits: {last_logits[-10:]}") + + # Show top 5 predicted tokens + top_indices = np.argsort(last_logits)[-5:][::-1] + print("Top 5 predictions:") + for idx in top_indices: + token = tokenizer.decode([idx]) + print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") + + save_output_data(last_logits, token_ids, prompt, model_name) + +if __name__ == "__main__": + main() diff --git a/llama.cpp/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh b/llama.cpp/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh new file mode 100755 index 0000000..984d03e --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/embedding/compare-embeddings-logits.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +set -e + +# Parse command line arguments +MODEL_PATH="" +MODEL_NAME="" +PROMPTS_FILE="" + +# First argument is always model path +if [ $# -gt 0 ] && [[ "$1" != --* ]]; then + MODEL_PATH="$1" + shift +fi + +# Parse remaining arguments +while [[ $# -gt 0 ]]; do + case $1 in + --prompts-file|-pf) + PROMPTS_FILE="$2" + shift 2 + ;; + *) + # If MODEL_NAME not set and this isn't a flag, use as model name + if [ -z "$MODEL_NAME" ] && [[ "$1" != --* ]]; then + MODEL_NAME="$1" + fi + shift + ;; + esac +done + +# Set defaults +MODEL_PATH="${MODEL_PATH:-"$EMBEDDING_MODEL_PATH"}" +MODEL_NAME="${MODEL_NAME:-$(basename "$MODEL_PATH")}" + 
+
+CONVERTED_MODEL_PATH="${CONVERTED_EMBEDDING_PATH:-"$CONVERTED_EMBEDDING_MODEL"}"
+CONVERTED_MODEL_NAME="${CONVERTED_MODEL_NAME:-$(basename "$CONVERTED_MODEL_PATH" .gguf)}"
+
+if [ -t 0 ]; then
+    CPP_EMBEDDINGS="data/llamacpp-${CONVERTED_MODEL_NAME}-embeddings.bin"
+else
+    # Process piped JSON data and convert to binary (matching logits.cpp format)
+    TEMP_FILE=$(mktemp /tmp/tmp.XXXXXX.bin)
+    python3 -c "
+import json
+import sys
+import struct
+
+data = json.load(sys.stdin)
+
+# Flatten all embeddings completely
+flattened = []
+for item in data:
+    embedding = item['embedding']
+    for token_embedding in embedding:
+        flattened.extend(token_embedding)
+
+print(f'Total embedding values: {len(flattened)}', file=sys.stderr)
+
+# Write as binary floats - matches logits.cpp fwrite format
+with open('$TEMP_FILE', 'wb') as f:
+    for value in flattened:
+        f.write(struct.pack('f', value))
+"
+    CPP_EMBEDDINGS="$TEMP_FILE"
+    trap "rm -f $TEMP_FILE" EXIT
+fi
+
+# Build the semantic_check.py command
+SEMANTIC_CMD="python scripts/utils/semantic_check.py --model-path $MODEL_PATH \
+    --python-embeddings data/pytorch-${MODEL_NAME}-embeddings.bin \
+    --cpp-embeddings $CPP_EMBEDDINGS"
+
+# Add prompts file if specified, otherwise use default prompt
+if [ -n "$PROMPTS_FILE" ]; then
+    SEMANTIC_CMD="$SEMANTIC_CMD --prompts-file \"$PROMPTS_FILE\""
+else
+    SEMANTIC_CMD="$SEMANTIC_CMD --prompt \"Hello world today\""
+fi
+
+# Execute the command
+eval $SEMANTIC_CMD
+
diff --git a/llama.cpp/examples/model-conversion/scripts/embedding/convert-model.sh b/llama.cpp/examples/model-conversion/scripts/embedding/convert-model.sh
new file mode 100755
index 0000000..9926350
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/embedding/convert-model.sh
@@ -0,0 +1,38 @@
+#!/usr/bin/env bash
+
+set -e
+
+# Parse command line arguments
+SENTENCE_TRANSFORMERS=""
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        -st|--sentence-transformers)
+            SENTENCE_TRANSFORMERS="--sentence-transformers-dense-modules"
+            shift
+            ;;
+        *)
+            echo "Unknown option: $1"
+            exit 1
+            ;;
+    esac
+done
+
+MODEL_NAME="${MODEL_NAME:-$(basename "$EMBEDDING_MODEL_PATH")}"
+OUTPUT_DIR="${OUTPUT_DIR:-../../models}"
+TYPE="${OUTTYPE:-f16}"
+METADATA_OVERRIDE="${METADATA_OVERRIDE:-}"
+CONVERTED_MODEL="${OUTPUT_DIR}/${MODEL_NAME}.gguf"
+
+echo "Model path: ${EMBEDDING_MODEL_PATH}"
+echo "Model name: ${MODEL_NAME}"
+echo "Data type: ${TYPE}"
+echo "Converted model path: ${CONVERTED_MODEL}"
+python ../../convert_hf_to_gguf.py --verbose \
+    ${EMBEDDING_MODEL_PATH} \
+    --outfile ${CONVERTED_MODEL} \
+    --outtype ${TYPE} \
+    ${SENTENCE_TRANSFORMERS}
+
+echo ""
+echo "The environment variable CONVERTED_EMBEDDING_MODEL can be set to this path using:"
+echo "export CONVERTED_EMBEDDING_MODEL=$(realpath ${CONVERTED_MODEL})"
diff --git a/llama.cpp/examples/model-conversion/scripts/embedding/modelcard.template b/llama.cpp/examples/model-conversion/scripts/embedding/modelcard.template
new file mode 100644
index 0000000..9e63042
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/embedding/modelcard.template
@@ -0,0 +1,48 @@
+---
+base_model:
+- {base_model}
+---
+# {model_name} GGUF
+
+Recommended way to run this model:
+
+```sh
+llama-server -hf {namespace}/{model_name}-GGUF --embeddings
+```
+
+Then the endpoint can be accessed at http://localhost:8080/embedding, for
+example using `curl`:
+```console
+curl --request POST \
+    --url http://localhost:8080/embedding \
+    --header "Content-Type: application/json" \
+    --data '{{"input": "Hello embeddings"}}' \
+    --silent
+``` + +Alternatively, the `llama-embedding` command line tool can be used: +```sh +llama-embedding -hf {namespace}/{model_name}-GGUF --verbose-prompt -p "Hello embeddings" +``` + +#### embd_normalize +When a model uses pooling, or the pooling method is specified using `--pooling`, +the normalization can be controlled by the `embd_normalize` parameter. + +The default value is `2` which means that the embeddings are normalized using +the Euclidean norm (L2). Other options are: +* -1 No normalization +* 0 Max absolute +* 1 Taxicab +* 2 Euclidean/L2 +* \>2 P-Norm + +This can be passed in the request body to `llama-server`, for example: +```sh + --data '{{"input": "Hello embeddings", "embd_normalize": -1}}' \ +``` + +And for `llama-embedding`, by passing `--embd-normalize <value>`, for example: +```sh +llama-embedding -hf {namespace}/{model_name}-GGUF --embd-normalize -1 -p "Hello embeddings" +``` diff --git a/llama.cpp/examples/model-conversion/scripts/embedding/run-converted-model.sh b/llama.cpp/examples/model-conversion/scripts/embedding/run-converted-model.sh new file mode 100755 index 0000000..ba8a3af --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +set -e + +# Parse command line arguments +CONVERTED_MODEL="" +PROMPTS_FILE="" +EMBD_NORMALIZE="2" + +while [[ $# -gt 0 ]]; do + case $1 in + -p|--prompts-file) + PROMPTS_FILE="$2" + shift 2 + ;; + --embd-normalize) + EMBD_NORMALIZE="$2" + shift 2 + ;; + *) + if [ -z "$CONVERTED_MODEL" ]; then + CONVERTED_MODEL="$1" + fi + shift + ;; + esac +done + +# First try command line argument, then environment variable +CONVERTED_MODEL="${CONVERTED_MODEL:-"$CONVERTED_EMBEDDING_MODEL"}" +BUILD_DIR="${BUILD_DIR:-"../../build"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. CONVERTED_EMBEDDING_MODEL environment variable" >&2 + exit 1 +fi + +# Read prompt from file or use default +if [ -n "$PROMPTS_FILE" ]; then + if [ ! 
-f "$PROMPTS_FILE" ]; then + echo "Error: Prompts file '$PROMPTS_FILE' not found" >&2 + exit 1 + fi + PROMPT=$(cat "$PROMPTS_FILE") +else + PROMPT="Hello world today" +fi + +echo $CONVERTED_MODEL + +cmake --build ${BUILD_DIR} --target llama-debug -j8 +${BUILD_DIR}/bin/llama-debug -m "$CONVERTED_MODEL" --embedding -p "$PROMPT" --save-logits --embd-normalize $EMBD_NORMALIZE diff --git a/llama.cpp/examples/model-conversion/scripts/embedding/run-original-model.py b/llama.cpp/examples/model-conversion/scripts/embedding/run-original-model.py new file mode 100755 index 0000000..0802cbc --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/embedding/run-original-model.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +import argparse +import os +import sys +import importlib + +from transformers import AutoTokenizer, AutoConfig, AutoModel +import torch + +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from utils.common import save_output_data + + +def parse_arguments(): + parser = argparse.ArgumentParser(description='Run original embedding model') + parser.add_argument( + '--model-path', + '-m', + help='Path to the model' + ) + parser.add_argument( + '--prompts-file', + '-p', + help='Path to file containing prompts (one per line)' + ) + parser.add_argument( + '--use-sentence-transformers', + action='store_true', + help=('Use SentenceTransformer to apply all numbered layers ' + '(01_Pooling, 02_Dense, 03_Dense, 04_Normalize)') + ) + parser.add_argument( + '--device', + '-d', + help='Device to use (cpu, cuda, mps, auto)', + default='auto' + ) + return parser.parse_args() + + +def load_model_and_tokenizer(model_path, use_sentence_transformers=False, device="auto"): + if device == "cpu": + device_map = {"": "cpu"} + print("Forcing CPU usage") + elif device == "auto": + # On Mac, "auto" device_map can cause issues with accelerate + # So we detect the best device manually + if torch.cuda.is_available(): + device_map = {"": "cuda"} + print("Using CUDA") + elif torch.backends.mps.is_available(): + device_map = {"": "mps"} + print("Using MPS (Apple Metal)") + else: + device_map = {"": "cpu"} + print("Using CPU") + else: + device_map = {"": device} + + if use_sentence_transformers: + from sentence_transformers import SentenceTransformer + print("Using SentenceTransformer to apply all numbered layers") + model = SentenceTransformer(model_path) + tokenizer = model.tokenizer + config = model[0].auto_model.config # type: ignore + else: + tokenizer = AutoTokenizer.from_pretrained(model_path) + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # This can be used to override the sliding window size for manual testing. This + # can be useful to verify the sliding window attention mask in the original model + # and compare it with the converted .gguf model. 
+
+    if hasattr(config, 'sliding_window'):
+        original_sliding_window = config.sliding_window
+        # config.sliding_window = 128  # Uncomment and adjust to override for testing.
+        print(f"Sliding window: {original_sliding_window} -> {config.sliding_window}")
+
+    unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
+    if unreleased_model_name:
+        print(f"Using unreleased model: {unreleased_model_name}")
+        model_name_lower = unreleased_model_name.lower()
+        unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+        class_name = f"{unreleased_model_name}Model"
+        print(f"Importing unreleased model module: {unreleased_module_path}")
+
+        try:
+            model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
+            model = model_class.from_pretrained(
+                model_path,
+                device_map=device_map,
+                offload_folder="offload",
+                trust_remote_code=True,
+                config=config
+            )
+        except (ImportError, AttributeError) as e:
+            print(f"Failed to import or load model: {e}")
+            sys.exit(1)
+    else:
+        model = AutoModel.from_pretrained(
+            model_path,
+            device_map=device_map,
+            offload_folder="offload",
+            trust_remote_code=True,
+            config=config
+        )
+    print(f"Model class: {type(model)}")
+    print(f"Model file: {type(model).__module__}")
+
+    # Verify the model is using the correct sliding window
+    if hasattr(model.config, 'sliding_window'):  # type: ignore
+        print(f"Model's sliding_window: {model.config.sliding_window}")  # type: ignore
+    else:
+        print("Model config does not have sliding_window attribute")
+
+    return model, tokenizer, config
+
+
+def get_prompt(args):
+    if args.prompts_file:
+        try:
+            with open(args.prompts_file, 'r', encoding='utf-8') as f:
+                return f.read().strip()
+        except FileNotFoundError:
+            print(f"Error: Prompts file '{args.prompts_file}' not found")
+            sys.exit(1)
+        except Exception as e:
+            print(f"Error reading prompts file: {e}")
+            sys.exit(1)
+    else:
+        return "Hello world today"
+
+
+def main():
+    args = parse_arguments()
+
+    model_path = os.environ.get('EMBEDDING_MODEL_PATH', args.model_path)
+    if model_path is None:
+        print("Error: Model path must be specified either via --model-path argument "
+              "or EMBEDDING_MODEL_PATH environment variable")
+        sys.exit(1)
+
+    # Determine if we should use SentenceTransformer
+    use_st = (
+        args.use_sentence_transformers or os.environ.get('USE_SENTENCE_TRANSFORMERS', '').lower() in ('1', 'true', 'yes')
+    )
+
+    model, tokenizer, config = load_model_and_tokenizer(model_path, use_st, args.device)
+
+    # Get the device the model is on
+    if not use_st:
+        device = next(model.parameters()).device
+    else:
+        # For SentenceTransformer, get device from the underlying model
+        device = next(model[0].auto_model.parameters()).device  # type: ignore
+
+    model_name = os.path.basename(model_path)
+
+    prompt_text = get_prompt(args)
+    texts = [prompt_text]
+
+    with torch.no_grad():
+        if use_st:
+            embeddings = model.encode(texts, convert_to_numpy=True)
+            all_embeddings = embeddings  # Shape: [batch_size, hidden_size]
+
+            encoded = tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt"
+            )
+            tokens = encoded['input_ids'][0]
+            token_ids = tokens.cpu().tolist()
+            token_strings = tokenizer.convert_ids_to_tokens(tokens)
+            for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+                print(f"{token_id:6d} -> '{token_str}'")
+
+            print(f"Embeddings shape (after all SentenceTransformer layers): {all_embeddings.shape}")
+            print(f"Embedding dimension: {all_embeddings.shape[1] if len(all_embeddings.shape) > 1 else all_embeddings.shape[0]}")  # type: ignore
+        else:
+            # Standard approach: use base 
model output only
+            encoded = tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                return_tensors="pt"
+            )
+
+            tokens = encoded['input_ids'][0]
+            token_ids = tokens.cpu().tolist()
+            token_strings = tokenizer.convert_ids_to_tokens(tokens)
+            for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
+                print(f"{token_id:6d} -> '{token_str}'")
+
+            # Move inputs to the same device as the model
+            encoded = {k: v.to(device) for k, v in encoded.items()}
+            outputs = model(**encoded)
+            hidden_states = outputs.last_hidden_state  # Shape: [batch_size, seq_len, hidden_size]
+
+            all_embeddings = hidden_states[0].float().cpu().numpy()  # Shape: [seq_len, hidden_size]
+
+            print(f"Hidden states shape: {hidden_states.shape}")
+            print(f"All embeddings shape: {all_embeddings.shape}")
+            print(f"Embedding dimension: {all_embeddings.shape[1]}")
+
+    if len(all_embeddings.shape) == 1:
+        n_embd = all_embeddings.shape[0]  # type: ignore
+        n_embd_count = 1
+        all_embeddings = all_embeddings.reshape(1, -1)
+    else:
+        n_embd = all_embeddings.shape[1]  # type: ignore
+        n_embd_count = all_embeddings.shape[0]  # type: ignore
+
+    print()
+
+    for j in range(n_embd_count):
+        embedding = all_embeddings[j]
+        print(f"embedding {j}: ", end="")
+
+        # Print first 3 values
+        for i in range(min(3, n_embd)):
+            print(f"{embedding[i]:9.6f} ", end="")
+
+        print(" ... ", end="")
+
+        # Print last 3 values
+        for i in range(n_embd - 3, n_embd):
+            print(f"{embedding[i]:9.6f} ", end="")
+
+        print()  # New line
+
+    print()
+
+    flattened_embeddings = all_embeddings.flatten()
+    print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
+    print("")
+
+    save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/__init__.py b/llama.cpp/examples/model-conversion/scripts/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/__init__.py
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/check-nmse.py b/llama.cpp/examples/model-conversion/scripts/utils/check-nmse.py
new file mode 100755
index 0000000..83f63f9
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/check-nmse.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+
+import numpy as np
+import sys
+import os
+import argparse
+from pathlib import Path
+from common import get_model_name_from_env_path  # type: ignore[import-not-found]
+
+def calculate_nmse(reference, test):
+    mse = np.mean((test - reference) ** 2)
+    ref_var = np.var(reference)
+    if ref_var == 0:
+        nmse = float('inf') if mse > 0 else 0.0
+        return nmse, mse, ref_var
+
+    nmse = mse / ref_var
+
+    return nmse, mse, ref_var
+
+def load_logits(file_path):
+    if not os.path.exists(file_path):
+        raise FileNotFoundError(f"File not found: {file_path}")
+
+    if file_path.suffix == '.npy':
+        return np.load(file_path)
+    elif file_path.suffix == '.bin':
+        return np.fromfile(file_path, dtype=np.float32)
+    else:
+        # Try to load as text file
+        try:
+            # If it has index format "0: value", extract just values
+            data = []
+            with open(file_path, 'r') as f:
+                for line in f:
+                    if ':' in line:
+                        # Format: "index: value"
+                        value = float(line.split(':')[1].strip())
+                    else:
+                        # Just the value
+                        value = float(line.strip())
+                    data.append(value)
+            return np.array(data, dtype=np.float32)
+        except ValueError:
+            return np.loadtxt(file_path, dtype=np.float32)
+
+def interpret_nmse(nmse):
+ """Provide interpretation of NMSE value""" + if nmse == 0: + return "Perfect match", "🎉" + elif nmse < 1e-6: + return "Essentially identical", "✅" + elif nmse < 1e-4: + return "Excellent match", "✅" + elif nmse < 1e-3: + return "Very good match", "👍" + elif nmse < 1e-2: + return "Good match", "👍" + elif nmse < 0.1: + return "Acceptable match", "⚠️" + elif nmse < 1.0: + return "Poor match", "❌" + else: + return "Very poor match (worse than noise)", "❌" + +def main(): + parser = argparse.ArgumentParser(description='Validate model logits') + parser.add_argument('-m', '--model-path', required=True, help='Path to the model directory') + args = parser.parse_args() + + model_name = get_model_name_from_env_path('MODEL_PATH') + data_dir = Path("data") + + pytorch_file = data_dir / f"pytorch-{model_name}.bin" + + llamacpp_model_name = get_model_name_from_env_path('CONVERTED_MODEL') + llamacpp_file = data_dir / f"llamacpp-{llamacpp_model_name}.bin" + + print(f"Model name: {model_name}") + print(f"PyTorch logits file: {pytorch_file}") + print(f"llama.cpp logits file: {llamacpp_file}") + + reference_file = pytorch_file + test_file = llamacpp_file + + print("📊 NMSE Check for Model Comparison") + print("=" * 50) + print(f"Reference (ground truth): {reference_file}") + print(f"Test (to evaluate): {test_file}") + print() + + try: + print("Loading reference logits...") + reference = load_logits(reference_file) + print(f" Shape: {reference.shape}, Type: {reference.dtype}") + + print("Loading test logits...") + test = load_logits(test_file) + print(f" Shape: {test.shape}, Type: {test.dtype}") + + # Check shapes match + if reference.shape != test.shape: + print(f"\n❌ Error: Shape mismatch!") + print(f" Reference: {reference.shape}") + print(f" Test: {test.shape}") + sys.exit(1) + + print(f"\n✅ Shapes match: {reference.shape}") + + nmse, mse, ref_var = calculate_nmse(reference, test) + + # Additional metrics + max_abs_error = np.max(np.abs(test - reference)) + mean_abs_error = np.mean(np.abs(test - reference)) + + # Results + print(f"\n📈 METRICS") + print("=" * 30) + print(f"MSE (Mean Squared Error): {mse:.6e}") + print(f"Reference Variance: {ref_var:.6e}") + print(f"NMSE: {nmse:.6e}") + print(f"Max Absolute Error: {max_abs_error:.6f}") + print(f"Mean Absolute Error: {mean_abs_error:.6f}") + + # NMSE in dB (common in signal processing) + if nmse > 0: + nmse_db = 10 * np.log10(nmse) + print(f"NMSE (dB): {nmse_db:.2f} dB") + + # Interpretation + interpretation, emoji = interpret_nmse(nmse) + print(f"\n🎯 INTERPRETATION") + print("=" * 30) + print(f"{emoji} {interpretation}") + + # Detailed guidance + print(f"\n📋 GUIDANCE") + print("=" * 30) + if nmse < 1e-3: + print("✅ EXCELLENT: Your GGML conversion is working very well!") + print(" The differences are negligible for practical use.") + elif nmse < 1e-2: + print("👍 GOOD: Your GGML conversion is working well.") + print(" Small differences are likely due to precision/quantization.") + elif nmse < 0.1: + print("⚠️ ACCEPTABLE: Conversion is working but with some differences.") + print(" Check if you're using quantization (Q4, Q8, etc.)") + print(" Test generation quality to see if it's acceptable.") + else: + print("❌ PROBLEMATIC: Large differences detected.") + print(" Check your conversion process for potential issues.") + print(" Verify you're using the same model weights.") + + # NMSE benchmarks + print(f"\n📚 NMSE BENCHMARKS") + print("=" * 30) + print("< 1e-6: Essentially identical") + print("< 1e-4: Excellent (typical for good conversions)") + print("< 1e-3: 
Very good") + print("< 1e-2: Good (acceptable for most use cases)") + print("< 0.1: Acceptable (may need verification)") + print("> 1.0: Poor (worse than random)") + + # Exit code based on NMSE + if nmse < 1e-2: + print(f"\n✅ RESULT: PASS (NMSE = {nmse:.2e})") + sys.exit(0) + else: + print(f"\n❌ RESULT: NEEDS REVIEW (NMSE = {nmse:.2e})") + sys.exit(1) + + except Exception as e: + print(f"❌ Error: {e}") + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/llama.cpp/examples/model-conversion/scripts/utils/common.py b/llama.cpp/examples/model-conversion/scripts/utils/common.py new file mode 100644 index 0000000..aa4bab2 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/common.py @@ -0,0 +1,299 @@ +#!/usr/bin/env python3 + +import os +import sys +import torch +import transformers +import json +import textwrap +import numpy as np +from pathlib import Path + + +def get_model_name_from_env_path(env_path_name): + model_path = os.getenv(env_path_name) + if not model_path: + print(f"Error: {env_path_name} environment variable not set") + sys.exit(1) + + if not os.path.exists(model_path): + print(f"Error: Model file not found: {model_path}") + sys.exit(1) + + name = os.path.basename(os.path.normpath(model_path)) + if name.endswith(".gguf"): + name = name[:-5] + + return name + + +def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3): + """ + Print a tensor in llama.cpp debug style. + + Supports: + - 2D tensors (seq, hidden) + - 3D tensors (batch, seq, hidden) + - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head + + Shows first and last max_vals of each vector per sequence position. + """ + t = tensor.detach().to(torch.float32).cpu() + + # Determine dimensions + if t.ndim == 3: + _, s, _ = t.shape + elif t.ndim == 2: + _, s = 1, t.shape[0] + t = t.unsqueeze(0) + elif t.ndim == 4: + _, s, _, _ = t.shape + else: + print(f"Skipping tensor due to unsupported dimensions: {t.ndim}") + return + + ten_shape = t.shape + + print(f"ggml_debug: {name} = (f32) ... 
= {{{ten_shape}}}") + print(" [") + print(" [") + + # Determine indices for first and last sequences + first_indices = list(range(min(s, max_seq))) + last_indices = list(range(max(0, s - max_seq), s)) + + # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq + has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s) + + # Combine indices + if has_overlap: + # If there's overlap, just use the combined unique indices + indices = sorted(list(set(first_indices + last_indices))) + separator_index = None + else: + # If no overlap, we'll add a separator between first and last sequences + indices = first_indices + last_indices + separator_index = len(first_indices) + + for i, si in enumerate(indices): + # Add separator if needed + if separator_index is not None and i == separator_index: + print(" ...") + + # Extract appropriate slice + vec = t[0, si] + if vec.ndim == 2: # 4D case: flatten heads × dim_per_head + flat = vec.flatten().tolist() + else: # 2D or 3D case + flat = vec.tolist() + + # First and last slices + first = flat[:max_vals] + last = flat[-max_vals:] if len(flat) >= max_vals else flat + first_str = ", ".join(f"{v:12.4f}" for v in first) + last_str = ", ".join(f"{v:12.4f}" for v in last) + + print(f" [{first_str}, ..., {last_str}]") + + print(" ],") + print(" ]") + print(f" sum = {t.sum().item():.6f}\n") + + +def debug_hook(name): + def fn(_m, input, output): + if isinstance(input, torch.Tensor): + summarize(input, name + "_in") + elif isinstance(input, (tuple, list)) and len(input) > 0 and isinstance(input[0], torch.Tensor): + summarize(input[0], name + "_in") + if isinstance(output, torch.Tensor): + summarize(output, name + "_out") + elif isinstance(output, (tuple, list)) and len(output) > 0 and isinstance(output[0], torch.Tensor): + summarize(output[0], name + "_out") + + return fn + + +def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_pos_emb"): + """ + Apply monkey patch to dump RoPE activations for debugging. + + Args: + model_module_path: Path to the model module (e.g., "transformers.models.apertus.modeling_apertus") + function_name: Name of the RoPE function to patch (default: "apply_rotary_pos_emb") + + Example: + from utils.common import setup_rope_debug + setup_rope_debug("transformers.models.apertus.modeling_apertus") + """ + import importlib + + # Import the module and get the original function + module = importlib.import_module(model_module_path) + orig_rope = getattr(module, function_name) + + # Set torch print options for better debugging + torch.set_printoptions(threshold=float('inf')) + torch.set_printoptions(precision=6, sci_mode=False) + + def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + # log inputs + summarize(q, "RoPE.q_in") + summarize(k, "RoPE.k_in") + + # call original + q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim) + + # log outputs + summarize(q_out, "RoPE.q_out") + summarize(k_out, "RoPE.k_out") + + return q_out, k_out + + # Patch it + setattr(module, function_name, debug_rope) + print(f"RoPE debug patching applied to {model_module_path}.{function_name}") + + +def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"): + """ + Save output data (logits/embeddings), tokens, and prompt to files. 
+ + Args: + data: numpy array of floats (logits or embeddings) + tokens: list or array of token IDs + prompt: string containing the input prompt + model_name: name of the model + type_suffix: optional suffix like "-embeddings" (default: "") + output_dir: directory to save files (default: "data") + + Creates the following files in output_dir: + - pytorch-{model_name}{type_suffix}.bin + - pytorch-{model_name}{type_suffix}.txt + - pytorch-{model_name}{type_suffix}-prompt.txt + - pytorch-{model_name}{type_suffix}-tokens.bin + """ + data_dir = Path(output_dir) + data_dir.mkdir(exist_ok=True) + base_path = data_dir / f"pytorch-{model_name}{type_suffix}" + + # Convert and flatten logits/embeddings + data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data) + data = data.flatten() if data.ndim > 1 else data + + # Save logits/embedding files + data.astype(np.float32).tofile(f"{base_path}.bin") + print(f"Data saved to {base_path}.bin") + + with open(f"{base_path}.txt", "w") as f: + f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data)) + print(f"Data saved to {base_path}.txt") + + # Convert and flatten tokens + tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens) + tokens = tokens.flatten() if tokens.ndim > 1 else tokens + + # Save token binary file + tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin") + print(f"Tokens saved to {base_path}-tokens.bin") + + # Save prompt file + with open(f"{base_path}-prompt.txt", "w") as f: + f.write(f"prompt: {prompt}\n") + f.write(f"n_tokens: {len(tokens)}\n") + f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n") + print(f"Prompt saved to {base_path}-prompt.txt") + + +def compare_tokens(original, converted, type_suffix="", output_dir="data"): + data_dir = Path(output_dir) + + # Read tokens from both models + tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin" + tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin" + + if not tokens1_file.exists(): + print(f"Error: Token file not found: {tokens1_file}") + return False + + if not tokens2_file.exists(): + print(f"Error: Token file not found: {tokens2_file}") + return False + + tokens1 = np.fromfile(tokens1_file, dtype=np.int32) + tokens2 = np.fromfile(tokens2_file, dtype=np.int32) + + print(f"\nComparing tokens between:") + print(f" Original : {original} ({len(tokens1)} tokens)") + print(f" Converted: {converted} ({len(tokens2)} tokens)") + + if len(tokens1) != len(tokens2): + print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}") + return False + + if np.array_equal(tokens1, tokens2): + print(f"\n✅ All {len(tokens1)} tokens match!") + return True + + mismatches = np.where(tokens1 != tokens2)[0] + print(f"\n❌ Found {len(mismatches)} mismatched tokens:") + + num_to_show = min(len(mismatches), 10) + for idx in mismatches[:num_to_show]: + print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}") + + if len(mismatches) > num_to_show: + print(f" ... 
and {len(mismatches) - num_to_show} more mismatches") + + return False + + +def show_version_warning(current_version, model_version): + if not model_version: + return False + + try: + from packaging.version import parse, InvalidVersion + try: + return parse(current_version) < parse(model_version) + except InvalidVersion: + return current_version != model_version + except ImportError: + return current_version != model_version + +def get_model_transformers_version(model_path): + if not model_path: + return None + + config_path = Path(model_path) / "config.json" + if not config_path.is_file(): + return None + + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + return config.get("transformers_version") + except (IOError, json.JSONDecodeError) as e: + print(f"Warning: Could not read or parse {config_path}: {e}", file=sys.stderr) + return None + +def exit_with_warning(message, model_path): + print(message) + + if model_path and transformers is not None: + model_transformers_version = get_model_transformers_version(model_path) + transformers_version = transformers.__version__ + if show_version_warning(transformers_version, model_transformers_version): + warning_message = f""" + ===================================================================== + Verification failure might be due to a transformers version mismatch: + + Current transformers version: {transformers_version} + Model's required version : {model_transformers_version} + + Consider installing the version specified by the model's config: + pip install transformers=={model_transformers_version} + ===================================================================== + """ + print(textwrap.dedent(warning_message)) + sys.exit(1) diff --git a/llama.cpp/examples/model-conversion/scripts/utils/compare_tokens.py b/llama.cpp/examples/model-conversion/scripts/utils/compare_tokens.py new file mode 100755 index 0000000..a286cb5 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/compare_tokens.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import argparse +import sys +from common import compare_tokens # type: ignore + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description='Compare tokens between two models', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 + """ + ) + parser.add_argument( + 'original', + help='Original model name' + ) + parser.add_argument( + 'converted', + help='Converted model name' + ) + parser.add_argument( + '-s', '--suffix', + default='', + help='Type suffix (e.g., "-embeddings")' + ) + parser.add_argument( + '-d', '--data-dir', + default='data', + help='Directory containing token files (default: data)' + ) + parser.add_argument( + '-v', '--verbose', + action='store_true', + help='Print prompts from both models' + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + if args.verbose: + from pathlib import Path + data_dir = Path(args.data_dir) + + prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt" + prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt" + + if prompt1_file.exists(): + print(f"\nOriginal model prompt ({args.original}):") + print(f" {prompt1_file.read_text().strip()}") + + if prompt2_file.exists(): + print(f"\nConverted model prompt ({args.converted}):") + print(f" {prompt2_file.read_text().strip()}") + + print() + + result = compare_tokens( + args.original, + args.converted, + 
type_suffix=args.suffix,
+        output_dir=args.data_dir
+    )
+
+    # Enable the script to be used in shell scripts so that they can check
+    # the exit code for success/failure.
+    sys.exit(0 if result else 1)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/create-collection-add-model.sh b/llama.cpp/examples/model-conversion/scripts/utils/create-collection-add-model.sh
new file mode 100644
index 0000000..485001b
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/create-collection-add-model.sh
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+COLLECTION_SLUG=$(python ./hf-create-collection.py \
+    --name "My Models" --description "My model collection" --namespace username --return-slug)
+echo "Created collection: $COLLECTION_SLUG"
+
+# Use it in the next command
+python ./hf-add-model-to-collection.py --collection "$COLLECTION_SLUG" --model "username/my-model"
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/curl-embedding-server.sh b/llama.cpp/examples/model-conversion/scripts/utils/curl-embedding-server.sh
new file mode 100755
index 0000000..7ed69e1
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/curl-embedding-server.sh
@@ -0,0 +1,6 @@
+#!/usr/bin/env bash
+curl --request POST \
+    --url http://localhost:8080/embedding \
+    --header "Content-Type: application/json" \
+    --data '{"input": "Hello world today"}' \
+    --silent
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/hf-add-model-to-collection.py b/llama.cpp/examples/model-conversion/scripts/utils/hf-add-model-to-collection.py
new file mode 100755
index 0000000..7e38af3
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/hf-add-model-to-collection.py
@@ -0,0 +1,80 @@
+#!/usr/bin/env python3
+
+from huggingface_hub import HfApi
+import argparse
+import sys
+
+def add_model_to_collection(collection_slug, model_id, note=""):
+    """
+    Add a model to an existing collection
+
+    Args:
+        collection_slug: The slug of the collection (e.g., "username/collection-name-12345")
+        model_id: The model repository ID (e.g., "username/model-name")
+        note: Optional note about the model
+
+    Returns:
+        True if successful, False if failed
+    """
+
+    # Initialize API
+    api = HfApi()
+
+    try:
+        user_info = api.whoami()
+        print(f"✅ Authenticated as: {user_info['name']}")
+
+        # Verify the model exists (model_info raises if it is missing or inaccessible)
+        print(f"🔍 Checking if model exists: {model_id}")
+        try:
+            api.model_info(model_id)
+        except Exception as e:
+            print(f"❌ Model not found or not accessible: {model_id}")
+            print(f"Error: {e}")
+            return False
+
+        print(f"📚 Adding model to collection...")
+        api.add_collection_item(
+            collection_slug=collection_slug,
+            item_id=model_id,
+            item_type="model",
+            note=note
+        )
+
+        print(f"✅ Model added to collection successfully!")
+        print(f"🔗 Collection URL: https://huggingface.co/collections/{collection_slug}")
+
+        return True
+
+    except Exception as e:
+        print(f"❌ Error adding model to collection: {e}")
+        return False
+
+def main():
+    # This script requires that the environment variable HF_TOKEN is set with your
+    # Hugging Face API token.
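+    #
+    # A hypothetical invocation (the collection slug and model id below are
+    # placeholders, not real repositories):
+    #
+    #   HF_TOKEN=<token> python hf-add-model-to-collection.py \
+    #       -c username/my-collection-64f9a55bb3115b4f513ec026 \
+    #       -m username/my-model \
+    #       -n "GGUF conversion of my-model"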
+ api = HfApi() + + parser = argparse.ArgumentParser(description='Add model to a Huggingface Collection') + parser.add_argument('--collection', '-c', help='The collection slug username/collection-hash', required=True) + parser.add_argument('--model', '-m', help='The model to add to the Collection', required=True) + parser.add_argument('--note', '-n', help='An optional note/description', required=False) + args = parser.parse_args() + + collection = args.collection + model = args.model + note = args.note + + success = add_model_to_collection( + collection_slug=collection, + model_id=model, + note=note + ) + + if success: + print("\n🎉 Model added successfully!") + else: + print("\n❌ Failed to add model to collection") + sys.exit(1) +if __name__ == "__main__": + main() diff --git a/llama.cpp/examples/model-conversion/scripts/utils/hf-create-collection.py b/llama.cpp/examples/model-conversion/scripts/utils/hf-create-collection.py new file mode 100755 index 0000000..e0fa60a --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/hf-create-collection.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 + +from huggingface_hub import HfApi +import argparse +import os +import sys + + +def create_collection(title, description, private=False, namespace=None, return_slug=False): + """ + Create a new collection on Hugging Face + + Args: + title: Collection title + description: Collection description + private: Whether the collection should be private (default: False) + namespace: Optional namespace (defaults to your username) + + Returns: + Collection object if successful, None if failed + """ + + # Check if HF_TOKEN is available + token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN") + if not token: + print("❌ No HF_TOKEN or HUGGINGFACE_HUB_TOKEN found in environment variables") + print("Please set your Hugging Face token as an environment variable") + return None + + # Initialize API + api = HfApi() + + try: + # Test authentication first + user_info = api.whoami() + if not return_slug: + print(f"✅ Authenticated as: {user_info['name']}") + + # Create the collection + if not return_slug: + print(f"📚 Creating collection: '{title}'...") + collection = api.create_collection( + title=title, + description=description, + private=private, + namespace=namespace + ) + + if not return_slug: + print(f"✅ Collection created successfully!") + print(f"📋 Collection slug: {collection.slug}") + print(f"🔗 Collection URL: https://huggingface.co/collections/{collection.slug}") + + return collection + + except Exception as e: + print(f"❌ Error creating collection: {e}") + return None + +def main(): + # This script requires that the environment variable HF_TOKEN is set with your + # Hugging Face API token. 
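+    #
+    # A hypothetical invocation (all names are placeholders); --return-slug makes
+    # the script print only the slug so that shell scripts can capture it:
+    #
+    #   HF_TOKEN=<token> python hf-create-collection.py \
+    #       --name "My GGUF Models" \
+    #       --description "GGUF conversions of my models" \
+    #       --namespace username \
+    #       --return-slug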
+    api = HfApi()
+
+    parser = argparse.ArgumentParser(description='Create a Huggingface Collection')
+    parser.add_argument('--name', '-n', help='The name/title of the Collection', required=True)
+    parser.add_argument('--description', '-d', help='The description for the Collection', required=True)
+    parser.add_argument('--namespace', '-ns', help='The namespace to add the Collection to', required=True)
+    parser.add_argument('--private', '-p', help='Create a private Collection', action='store_true')
+    parser.add_argument('--return-slug', '-s', help='Only output the collection slug', action='store_true')
+
+    args = parser.parse_args()
+
+    name = args.name
+    description = args.description
+    private = args.private
+    namespace = args.namespace
+    return_slug = args.return_slug
+
+    if not return_slug:
+        print("🚀 Creating Hugging Face Collection")
+        print(f"Title: {name}")
+        print(f"Description: {description}")
+        print(f"Namespace: {namespace}")
+        print(f"Private: {private}")
+
+    collection = create_collection(
+        title=name,
+        description=description,
+        private=private,
+        namespace=namespace,
+        return_slug=return_slug
+    )
+
+    if collection:
+        if return_slug:
+            print(collection.slug)
+        else:
+            print("\n🎉 Collection created successfully!")
+            print(f"Use this slug to add models: {collection.slug}")
+    else:
+        print("\n❌ Failed to create collection")
+        sys.exit(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/hf-create-model.py b/llama.cpp/examples/model-conversion/scripts/utils/hf-create-model.py
new file mode 100755
index 0000000..ea99bd8
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/hf-create-model.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+from huggingface_hub import HfApi
+import argparse
+
+# This script requires that the environment variable HF_TOKEN is set with your
+# Hugging Face API token.
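+#
+# A hypothetical invocation (names are placeholders); the repository is created
+# as <namespace>/<model-name>-GGUF:
+#
+#   HF_TOKEN=<token> python hf-create-model.py -m my-model -ns username -b org/original-model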
+api = HfApi()
+
+def load_template_and_substitute(template_path, **kwargs):
+    try:
+        with open(template_path, 'r', encoding='utf-8') as f:
+            template_content = f.read()
+
+        return template_content.format(**kwargs)
+    except FileNotFoundError:
+        print(f"Template file '{template_path}' not found!")
+        return None
+    except KeyError as e:
+        print(f"Missing template variable: {e}")
+        return None
+
+parser = argparse.ArgumentParser(description='Create a new Hugging Face model repository')
+parser.add_argument('--model-name', '-m', help='Name for the model', required=True)
+parser.add_argument('--namespace', '-ns', help='Namespace to add the model to', required=True)
+parser.add_argument('--org-base-model', '-b', help='Original Base model name', default="")
+parser.add_argument('--no-card', action='store_true', help='Skip creating model card')
+parser.add_argument('--private', '-p', action='store_true', help='Create private model')
+parser.add_argument('--embedding', '-e', action='store_true', help='Use embedding model card template')
+parser.add_argument('--dry-run', '-d', action='store_true', help='Print repository info and template without creating repository')
+
+args = parser.parse_args()
+
+repo_id = f"{args.namespace}/{args.model_name}-GGUF"
+print("Repository ID: ", repo_id)
+
+repo_url = None
+if not args.dry_run:
+    repo_url = api.create_repo(
+        repo_id=repo_id,
+        repo_type="model",
+        private=args.private,
+        exist_ok=False
+    )
+
+if not args.no_card:
+    if args.embedding:
+        template_path = "scripts/embedding/modelcard.template"
+    else:
+        template_path = "scripts/causal/modelcard.template"
+
+    print("Template path: ", template_path)
+
+    model_card_content = load_template_and_substitute(
+        template_path,
+        model_name=args.model_name,
+        namespace=args.namespace,
+        base_model=args.org_base_model,
+    )
+
+    if args.dry_run:
+        print("\nTemplate Content:\n")
+        print(model_card_content)
+    else:
+        if model_card_content:
+            api.upload_file(
+                path_or_fileobj=model_card_content.encode('utf-8'),
+                path_in_repo="README.md",
+                repo_id=repo_id
+            )
+            print("Model card created successfully.")
+        else:
+            print("Failed to create model card.")
+
+if not args.dry_run and repo_url:
+    print(f"Repository created: {repo_url}")
+
+
diff --git a/llama.cpp/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py b/llama.cpp/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py
new file mode 100755
index 0000000..15ccb11
--- /dev/null
+++ b/llama.cpp/examples/model-conversion/scripts/utils/hf-upload-gguf-model.py
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+from huggingface_hub import HfApi
+import argparse
+import os
+
+def upload_gguf_file(local_file_path, repo_id, filename_in_repo=None):
+    """
+    Upload a GGUF file to a Hugging Face model repository
+
+    Args:
+        local_file_path: Path to your local GGUF file
+        repo_id: Your repository ID (e.g., "username/model-name")
+        filename_in_repo: Optional custom name for the file in the repo
+    """
+
+    if not os.path.exists(local_file_path):
+        print(f"❌ File not found: {local_file_path}")
+        return False
+
+    # Default to the local file name if no explicit name was given
+    if not filename_in_repo:
+        filename_in_repo = os.path.basename(local_file_path)
+
+    print(f"📤 Uploading {local_file_path} to {repo_id}/{filename_in_repo}")
+
+    api = HfApi()
+
+    try:
+        api.upload_file(
+            path_or_fileobj=local_file_path,
+            path_in_repo=filename_in_repo,
+            repo_id=repo_id,
+            repo_type="model",
+            commit_message=f"Upload 
{filename_in_repo}" + ) + + print("✅ Upload successful!") + print(f"🔗 File available at: https://huggingface.co/{repo_id}/blob/main/{filename_in_repo}") + return True + + except Exception as e: + print(f"❌ Upload failed: {e}") + return False + +# This script requires that the environment variable HF_TOKEN is set with your +# Hugging Face API token. +api = HfApi() + +parser = argparse.ArgumentParser(description='Upload a GGUF model to a Huggingface model repository') +parser.add_argument('--gguf-model-path', '-m', help='The GGUF model file to upload', required=True) +parser.add_argument('--repo-id', '-r', help='The repository to upload to', required=True) +parser.add_argument('--name', '-o', help='The name in the model repository', required=False) +args = parser.parse_args() + +upload_gguf_file(args.gguf_model_path, args.repo_id, args.name) diff --git a/llama.cpp/examples/model-conversion/scripts/utils/inspect-converted-model.sh b/llama.cpp/examples/model-conversion/scripts/utils/inspect-converted-model.sh new file mode 100755 index 0000000..32d8482 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/inspect-converted-model.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash + +# First try command line argument, then environment variable, then file +CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +../../gguf-py/gguf/scripts/gguf_dump.py $CONVERTED_MODEL diff --git a/llama.cpp/examples/model-conversion/scripts/utils/inspect-org-model.py b/llama.cpp/examples/model-conversion/scripts/utils/inspect-org-model.py new file mode 100755 index 0000000..bc6f45a --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/inspect-org-model.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +import argparse +import os +import json +from safetensors import safe_open +from collections import defaultdict + +parser = argparse.ArgumentParser(description='Process model with specified path') +parser.add_argument('--model-path', '-m', help='Path to the model') +args = parser.parse_args() + +model_path = os.environ.get('MODEL_PATH', args.model_path) +if model_path is None: + parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") + +# Check if there's an index file (multi-file model) +index_path = os.path.join(model_path, "model.safetensors.index.json") +single_file_path = os.path.join(model_path, "model.safetensors") + +if os.path.exists(index_path): + # Multi-file model + print("Multi-file model detected") + + with open(index_path, 'r') as f: + index_data = json.load(f) + + # Get the weight map (tensor_name -> file_name) + weight_map = index_data.get("weight_map", {}) + + # Group tensors by file for efficient processing + file_tensors = defaultdict(list) + for tensor_name, file_name in weight_map.items(): + file_tensors[file_name].append(tensor_name) + + print("Tensors in model:") + + # Process each shard file + for file_name, tensor_names in file_tensors.items(): + file_path = os.path.join(model_path, file_name) + print(f"\n--- From {file_name} ---") + + with safe_open(file_path, framework="pt") as f: + for tensor_name in sorted(tensor_names): + tensor = f.get_tensor(tensor_name) + print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}") + +elif os.path.exists(single_file_path): + # Single 
file model (original behavior) + print("Single-file model detected") + + with safe_open(single_file_path, framework="pt") as f: + keys = f.keys() + print("Tensors in model:") + for key in sorted(keys): + tensor = f.get_tensor(key) + print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}") + +else: + print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}") + print("Available files:") + if os.path.exists(model_path): + for item in sorted(os.listdir(model_path)): + print(f" {item}") + else: + print(f" Directory {model_path} does not exist") + exit(1) diff --git a/llama.cpp/examples/model-conversion/scripts/utils/perplexity-gen.sh b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-gen.sh new file mode 100755 index 0000000..ef4b650 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-gen.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash + +set -e + +CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +# Check if data/wikitext-2-raw directory exists +if [ ! -d "ppl/wikitext-2-raw" ]; then + echo "ppl/wikitext-2-raw directory does not exist. Downloading..." >&2 + mkdir -p ppl + pushd ppl + ./../../../scripts/get-wikitext-2.sh + popd +fi + +mkdir -p ppl +OUTPUTFILE="ppl/$(basename $CONVERTED_MODEL).kld" +echo "Model: $CONVERTED_MODEL" + +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-perplexity -j8 + +${BUILD_DIR}/bin/llama-perplexity -m $CONVERTED_MODEL \ + -f ppl/wikitext-2-raw/wiki.test.raw \ + --kl-divergence-base $OUTPUTFILE + +echo "Generated logits in $OUTPUTFILE" + diff --git a/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run-simple.sh b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run-simple.sh new file mode 100755 index 0000000..20ee965 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run-simple.sh @@ -0,0 +1,32 @@ +#!/usr/bin/env bash + +set -e + +QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" + +if [ -z "$QUANTIZED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. QUANTIZED_MODEL environment variable" >&2 + exit 1 +fi + +# Check if data/wikitext-2-raw directory exists +if [ ! -d "ppl/wikitext-2-raw" ]; then + echo "ppl/wikitext-2-raw directory does not exist. Downloading..." >&2 + mkdir -p ppl + pushd ppl + ./../../../scripts/get-wikitext-2.sh + popd +fi + +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-perplexity -j8 + +${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL -f ppl/wikitext-2-raw/wiki.test.raw + + diff --git a/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run.sh b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run.sh new file mode 100755 index 0000000..c11f32c --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/perplexity-run.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +set -e + +QUANTIZED_MODEL="${1:-"$QUANTIZED_MODEL"}" +LOGITS_FILE="${2:-"$LOGITS_FILE"}" +BUILD_DIR="${3:-"$BUILD_DIR"}" + +if [ -z "$QUANTIZED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. 
Command line argument" >&2 + echo " 2. QUANTIZED_MODEL environment variable" >&2 + exit 1 +fi + +if [ ! -f ${LOGITS_FILE} ]; then + echo "Error: logits file '${LOGITS_FILE} was not found" + echo "Did you run the perplexity-gen.sh script?" + exit 1 +fi + +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +echo "Model: $QUANTIZED_MODEL" +echo "Data file: $LOGITS_FILE" + +cmake --build $BUILD_DIR --target llama-perplexity -j8 + +${BUILD_DIR}/bin/llama-perplexity -m $QUANTIZED_MODEL \ + --kl-divergence-base $LOGITS_FILE \ + --kl-divergence diff --git a/llama.cpp/examples/model-conversion/scripts/utils/quantize.sh b/llama.cpp/examples/model-conversion/scripts/utils/quantize.sh new file mode 100755 index 0000000..4c21a13 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/quantize.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +set -e + +CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +QUANTIZED_TYPE="${2:-"$QUANTIZED_TYPE"}" +TOKEN_EMBD_TYPE="${3:-"${TOKEN_EMBD_TYPE}"}" +OUTPUT_TYPE="${4:-"${OUTPUT_TYPE}"}" +BUILD_DIR="${5:-"$BUILD_DIR"}" +QUANTIZED_MODEL=$CONVERTED_MODEL + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +if [ -z "$QUANTIZED_TYPE" ]; then + echo "Error: QUANTIZED_TYPE is required" >&2 + exit 1 +fi + +echo $CONVERTED_MODEL + +# Process the quantized model filename +if [[ "$QUANTIZED_MODEL" == *.gguf ]]; then + # Remove .gguf suffix, add quantized type, then add .gguf back + BASE_NAME="${QUANTIZED_MODEL%.gguf}" + QUANTIZED_MODEL="${BASE_NAME}-${QUANTIZED_TYPE}.gguf" +else + echo "Error: QUANTIZED_MODEL must end with .gguf extension" >&2 + exit 1 +fi + +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +cmake --build $BUILD_DIR --target llama-quantize -j8 + +echo $TOKEN_EMBD_TYPE +echo $OUTPUT_TYPE + +CMD_ARGS=("${BUILD_DIR}/bin/llama-quantize") +[[ -n "$TOKEN_EMBD_TYPE" ]] && CMD_ARGS+=("--token-embedding-type" "$TOKEN_EMBD_TYPE") +[[ -n "$OUTPUT_TYPE" ]] && CMD_ARGS+=("--output-tensor-type" "$OUTPUT_TYPE") +CMD_ARGS+=("$CONVERTED_MODEL" "$QUANTIZED_MODEL" "$QUANTIZED_TYPE") + +"${CMD_ARGS[@]}" + +echo "Quantized model saved to: $QUANTIZED_MODEL" diff --git a/llama.cpp/examples/model-conversion/scripts/utils/run-embedding-server.sh b/llama.cpp/examples/model-conversion/scripts/utils/run-embedding-server.sh new file mode 100755 index 0000000..9f5fc2c --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/run-embedding-server.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +set -e +# +# First try command line argument, then environment variable, then file +CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" +BUILD_DIR="${2:-"$BUILD_DIR"}" + +# Final check if we have a model path +if [ -z "$CONVERTED_MODEL" ]; then + echo "Error: Model path must be provided either as:" >&2 + echo " 1. Command line argument" >&2 + echo " 2. 
CONVERTED_MODEL environment variable" >&2 + exit 1 +fi + +if [ -z "$BUILD_DIR" ]; then + BUILD_DIR="../../build" +fi + +echo $CONVERTED_MODEL + +cmake --build $BUILD_DIR --target llama-server + +${BUILD_DIR}/bin/llama-server -m $CONVERTED_MODEL \ + --embedding \ + --pooling none diff --git a/llama.cpp/examples/model-conversion/scripts/utils/semantic_check.py b/llama.cpp/examples/model-conversion/scripts/utils/semantic_check.py new file mode 100644 index 0000000..73e20ea --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/semantic_check.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 + +import numpy as np +import argparse +import os +import importlib +from pathlib import Path + +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel +from common import compare_tokens, exit_with_warning # type: ignore[import-not-found] + +unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') + +def cosine_similarity(a, b=None): + a = np.asarray(a) + if b is None: + b = a + else: + b = np.asarray(b) + + if a.ndim == 1: + a = a.reshape(1, -1) + if b.ndim == 1: + b = b.reshape(1, -1) + + a_norms = np.linalg.norm(a, axis=1, keepdims=True) + b_norms = np.linalg.norm(b, axis=1, keepdims=True) + + a_norms = np.where(a_norms == 0, 1e-8, a_norms) + b_norms = np.where(b_norms == 0, 1e-8, b_norms) + + a_normalized = a / a_norms + b_normalized = b / b_norms + + # Compute cosine similarity + return np.dot(a_normalized, b_normalized.T) + +def load_embeddings_from_file(filename, n_tokens, n_embd): + embeddings = np.fromfile(filename, dtype=np.float32) + # Check if this is pooled (single embedding) or per-token embeddings + if len(embeddings) == n_embd: + return embeddings.reshape(1, n_embd) + else: + return embeddings.reshape(n_tokens, n_embd) + +def test_single_prompt_similarity(python_emb, cpp_emb, tokens, prompt): + np.set_printoptions(suppress=True, precision=6) + print("pytorch embeddings:"); + print(python_emb) + print("llama.cpp embeddings:"); + print(cpp_emb) + print(f"\n=== Prompt: '{prompt}' ===") + print(f"Tokens: {tokens}") + print(f"Embeddings shape: Python {python_emb.shape}, llama.cpp {cpp_emb.shape}") + + n_tokens = len(tokens) + is_pooled = python_emb.shape[0] == 1 + + if is_pooled: + print(f"\n[Pooled Embeddings Mode - comparing single sentence embeddings]") + + # 1. Direct embedding comparison for pooled embeddings + print(f"\n1. Raw Embedding Magnitude Comparison:") + py_mag = np.linalg.norm(python_emb[0]) + cpp_mag = np.linalg.norm(cpp_emb[0]) + ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf') + print(f" Pooled embedding: Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}") + + # 2. Cross-model similarity for pooled embeddings + print(f"\n2. Cross-Model Pooled Embedding Similarity:") + sim = cosine_similarity([python_emb[0]], [cpp_emb[0]])[0][0] + print(f" Cosine similarity: {sim:.6f}") + + return { + 'cross_model_similarities': [sim], + 'similarity_matrix_diff': np.array([[0.0]]), + 'max_diff': 0.0, + 'mean_diff': 0.0, + 'rms_diff': 0.0 + } + else: + # Original per-token comparison logic + # 1. Direct embedding comparison + print(f"\n1. Raw Embedding Magnitude Comparison:") + # Check if the distance of each token embedding from the origin and compare + # if the vectors are on the same "sphere". This does not tell us about + # direction (meaning of the token embedding), just magnitude. 
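+        # For example, ||py_i|| = 12.3 and ||cpp_i|| = 12.4 gives ratio ≈ 0.99:
+        # the two embeddings lie on (nearly) the same sphere, even though the
+        # ratio says nothing about the angle between them.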
+        for i in range(n_tokens):
+            py_mag = np.linalg.norm(python_emb[i])  # calculate standard euclidean norm for Python embeddings
+            cpp_mag = np.linalg.norm(cpp_emb[i])    # calculate standard euclidean norm for llama.cpp embeddings
+            ratio = py_mag / cpp_mag if cpp_mag > 0 else float('inf')
+            print(f" Token {i} ({tokens[i]}): Python={py_mag:.3f}, llama.cpp={cpp_mag:.3f}, ratio={ratio:.3f}")
+
+        # 2. Cosine similarity between tokens within each model
+        # Here we check the direction of token embeddings to see if they have the
+        # same meaning (similarity). This is done by calculating cosine similarity
+        # of a pair of token embeddings within each model.
+        print(f"\n2. Within-Model Token Similarities:")
+        print(" Python model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([python_emb[i]], [python_emb[j]])[0][0]
+                print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        print(" llama.cpp model:")
+        for i in range(n_tokens):
+            for j in range(i+1, n_tokens):
+                sim = cosine_similarity([cpp_emb[i]], [cpp_emb[j]])[0][0]
+                print(f" {tokens[i]} ↔ {tokens[j]}: {sim:.4f}")
+
+        # 3. Cross-model similarity (same token position)
+        print(f"\n3. Cross-Model Same-Token Similarities:")
+        for i in range(n_tokens):
+            sim = cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0]
+            print(f" Token {i} ({tokens[i]}): {sim:.4f}")
+
+        # 4. Similarity matrix comparison
+        print(f"\n4. Similarity Matrix Differences:")
+        py_sim_matrix = cosine_similarity(python_emb)
+        cpp_sim_matrix = cosine_similarity(cpp_emb)
+        diff_matrix = np.abs(py_sim_matrix - cpp_sim_matrix)
+
+        print(f" Max difference: {np.max(diff_matrix):.4f}")
+        print(f" Mean difference: {np.mean(diff_matrix):.4f}")
+        print(f" RMS difference: {np.sqrt(np.mean(diff_matrix**2)):.4f}")
+
+        return {
+            'cross_model_similarities': [cosine_similarity([python_emb[i]], [cpp_emb[i]])[0][0] for i in range(n_tokens)],
+            'similarity_matrix_diff': diff_matrix,
+            'max_diff': np.max(diff_matrix),
+            'mean_diff': np.mean(diff_matrix),
+            'rms_diff': np.sqrt(np.mean(diff_matrix**2))
+        }
+
+def read_prompt_from_file(file_path):
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read().strip()
+    except FileNotFoundError:
+        print(f"Error: Prompts file '{file_path}' not found")
+        exit(1)
+    except Exception as e:
+        print(f"Error reading prompts file: {e}")
+        exit(1)
+
+def main():
+    parser = argparse.ArgumentParser(description='Test semantic similarity between Python and llama.cpp embeddings')
+    parser.add_argument('--model-path', '-m', required=True, help='Path to the original Python model')
+    parser.add_argument('--python-embeddings', '-pe', help='Path to pytorch embeddings "logits" binary file')
+    parser.add_argument('--cpp-embeddings', '-ce', help='Path to llama.cpp embeddings "logits" binary file')
+    parser.add_argument('--causal', '-c', default=False, help='if the model is causal (default: false)', action='store_true')
+    parser.add_argument('--prompt', '-p', default='Hello world today', help='Test prompt')
+    parser.add_argument('--prompts-file', '-pf', help='Path to file containing prompts')
+
+    args = parser.parse_args()
+
+    if args.prompts_file:
+        prompt = read_prompt_from_file(args.prompts_file)
+    else:
+        prompt = args.prompt
+
+    python_emb_path = Path(args.python_embeddings)
+    cpp_emb_path = Path(args.cpp_embeddings)
+
+    # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
+    python_model_name = python_emb_path.stem.replace("-embeddings", "")
+    cpp_model_name = 
cpp_emb_path.stem.replace("-embeddings", "") + + print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") + print("=" * 70) + + # First verify tokens match before comparing embeddings + print("\n🔍 Token Comparison Check") + print("=" * 70) + data_dir = python_emb_path.parent + if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)): + exit_with_warning("\n❌ Token mismatch detected", args.model_path) + print() + + # Single prompt detailed comparison + print(f"\nTesting with prompt: '{prompt}'") + + # Load the python model to get configuration information and also to load the tokenizer. + print("Loading model and tokenizer using AutoTokenizer:", args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) + + if unreleased_model_name: + model_name_lower = unreleased_model_name.lower() + unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" + if args.causal: + class_name = f"{unreleased_model_name}ForCausalLM" + else: + class_name = f"{unreleased_model_name}Model" + print(f"Model class: {class_name}") + print(f"Importing unreleased model module: {unreleased_module_path}") + + try: + model_class = getattr(importlib.import_module(unreleased_module_path), class_name) + model = model_class.from_pretrained(args.model_path) + except (ImportError, AttributeError) as e: + print(f"Failed to import or load model: {e}") + exit(1) + else: + if args.causal: + model = AutoModelForCausalLM.from_pretrained(args.model_path, trust_remote_code=True) + else: + model = AutoModel.from_pretrained(args.model_path, trust_remote_code=True) + + encoded = tokenizer(prompt, return_tensors="pt") + tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) + n_tokens = len(tokens) + print(f"n_tokens: {n_tokens}"); + print(f"hidden_size: {model.config.hidden_size}") + + # Load binary embeddings from data directory. 
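+    # The .bin files are flat float32 dumps; load_embeddings_from_file reshapes
+    # them to (n_tokens, hidden_size), or to (1, hidden_size) when the file
+    # holds a single pooled embedding.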
+ llamacpp_embeddings = load_embeddings_from_file(args.cpp_embeddings, n_tokens, model.config.hidden_size) + python_embeddings = load_embeddings_from_file(args.python_embeddings, n_tokens, model.config.hidden_size) + + # Run comparison + results = test_single_prompt_similarity(python_embeddings, llamacpp_embeddings, tokens, prompt) + + # Summary + print(f"\n=== SUMMARY ===") + avg_cross_sim = np.mean(results['cross_model_similarities']) + print(f"Average cross-model similarity: {avg_cross_sim:.4f}") + print(f"Similarity matrix RMS difference: {results['rms_diff']:.4f}") + + # Quality assessment + if avg_cross_sim > 0.95: + print("✅ EXCELLENT: Models are highly similar") + elif avg_cross_sim > 0.90: + print("✅ VERY GOOD: Models are very similar") + elif avg_cross_sim > 0.80: + print("⚠️ GOOD: Models are reasonably similar") + elif avg_cross_sim > 0.70: + print("⚠️ FAIR: Models have some differences") + else: + exit_with_warning("❌ POOR: Models are significantly different", args.model_path) + +if __name__ == "__main__": + main() diff --git a/llama.cpp/examples/model-conversion/scripts/utils/tensor-info.py b/llama.cpp/examples/model-conversion/scripts/utils/tensor-info.py new file mode 100755 index 0000000..12a3430 --- /dev/null +++ b/llama.cpp/examples/model-conversion/scripts/utils/tensor-info.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +import argparse +import json +import os +import re +import sys +from pathlib import Path +from typing import Optional +from safetensors import safe_open + + +MODEL_SAFETENSORS_FILE = "model.safetensors" +MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" + + +def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: + index_file = model_path / MODEL_SAFETENSORS_INDEX + + if index_file.exists(): + with open(index_file, 'r') as f: + index = json.load(f) + return index.get("weight_map", {}) + + return None + + +def get_all_tensor_names(model_path: Path) -> list[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return list(weight_map.keys()) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + try: + with safe_open(single_file, framework="pt", device="cpu") as f: + return list(f.keys()) + except Exception as e: + print(f"Error reading {single_file}: {e}") + sys.exit(1) + + print(f"Error: No safetensors files found in {model_path}") + sys.exit(1) + + +def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return weight_map.get(tensor_name) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + return single_file.name + + return None + + +def normalize_tensor_name(tensor_name: str) -> str: + normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) + normalized = re.sub(r'\.\d+$', '.#', normalized) + return normalized + + +def list_all_tensors(model_path: Path, unique: bool = False): + tensor_names = get_all_tensor_names(model_path) + + if unique: + seen = set() + for tensor_name in sorted(tensor_names): + normalized = normalize_tensor_name(tensor_name) + if normalized not in seen: + seen.add(normalized) + print(normalized) + else: + for tensor_name in sorted(tensor_names): + print(tensor_name) + + +def print_tensor_info(model_path: Path, tensor_name: str): + tensor_file = find_tensor_file(model_path, tensor_name) + + if tensor_file is None: + print(f"Error: Could not find tensor '{tensor_name}' in model index") + print(f"Model path: {model_path}") + sys.exit(1) + + 
file_path = model_path / tensor_file + + try: + with safe_open(file_path, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + print(f"Tensor: {tensor_name}") + print(f"File: {tensor_file}") + print(f"Shape: {shape}") + else: + print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") + sys.exit(1) + + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + sys.exit(1) + except Exception as e: + print(f"An error occurred: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Print tensor information from a safetensors model" + ) + parser.add_argument( + "tensor_name", + nargs="?", # optional (if --list is used for example) + help="Name of the tensor to inspect" + ) + parser.add_argument( + "-m", "--model-path", + type=Path, + help="Path to the model directory (default: MODEL_PATH environment variable)" + ) + parser.add_argument( + "-l", "--list", + action="store_true", + help="List unique tensor patterns in the model (layer numbers replaced with #)" + ) + + args = parser.parse_args() + + model_path = args.model_path + if model_path is None: + model_path_str = os.environ.get("MODEL_PATH") + if model_path_str is None: + print("Error: --model-path not provided and MODEL_PATH environment variable not set") + sys.exit(1) + model_path = Path(model_path_str) + + if not model_path.exists(): + print(f"Error: Model path does not exist: {model_path}") + sys.exit(1) + + if not model_path.is_dir(): + print(f"Error: Model path is not a directory: {model_path}") + sys.exit(1) + + if args.list: + list_all_tensors(model_path, unique=True) + else: + if args.tensor_name is None: + print("Error: tensor_name is required when not using --list") + sys.exit(1) + print_tensor_info(model_path, args.tensor_name) + + +if __name__ == "__main__": + main() diff --git a/llama.cpp/examples/parallel/CMakeLists.txt b/llama.cpp/examples/parallel/CMakeLists.txt new file mode 100644 index 0000000..847e916 --- /dev/null +++ b/llama.cpp/examples/parallel/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-parallel) +add_executable(${TARGET} parallel.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/parallel/README.md b/llama.cpp/examples/parallel/README.md new file mode 100644 index 0000000..2468a30 --- /dev/null +++ b/llama.cpp/examples/parallel/README.md @@ -0,0 +1,14 @@ +# llama.cpp/example/parallel + +Simplified simulation of serving incoming requests in parallel + +## Example + +Generate 128 client requests (`-ns 128`), simulating 8 concurrent clients (`-np 8`). The system prompt is shared (`-pps`), meaning that it is computed once at the start. The client requests consist of up to 10 junk questions (`--junk 10`) followed by the actual question. + +```bash +llama-parallel -m model.gguf -np 8 -ns 128 --top-k 1 -pps --junk 10 -c 16384 +``` + +> [!NOTE] +> It's recommended to use base models with this example. Instruction tuned models might not be able to properly follow the custom chat template specified here, so the results might not be as expected. 
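+
+For a quicker smoke test, the same flags can be scaled down (the model path is a placeholder):
+
+```bash
+llama-parallel -m model.gguf -np 4 -ns 16 --top-k 1 -pps --junk 4 -c 8192
+```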
diff --git a/llama.cpp/examples/parallel/parallel.cpp b/llama.cpp/examples/parallel/parallel.cpp new file mode 100644 index 0000000..c92173a --- /dev/null +++ b/llama.cpp/examples/parallel/parallel.cpp @@ -0,0 +1,517 @@ +// A basic application simulating a server with multiple clients. +// The clients submit requests to the server and they are processed in parallel. + +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + +#include <cmath> +#include <cstdio> +#include <string> +#include <vector> +#include <ctime> +#include <algorithm> + +// trim whitespace from the beginning and end of a string +static std::string trim(const std::string & str) { + size_t start = 0; + size_t end = str.size(); + + while (start < end && isspace(str[start])) { + start += 1; + } + + while (end > start && isspace(str[end - 1])) { + end -= 1; + } + + return str.substr(start, end - start); +} + +static std::string k_system = +R"(Transcript of a never ending dialog, where the User interacts with an Assistant. +The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User: +Recommend a nice restaurant in the area. +Assistant: +I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays. +User: +Who is Richard Feynman? +Assistant: +Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?". +)"; + +static std::vector<std::string> k_questions = { + "What is the tallest mountain in the world?", + "Who was the first person to win two Nobel Prizes?", + "Which country invented paper?", + "What organ is primarily responsible for pumping blood throughout the body?", + "Which planet is known for its prominent ring system?", + "Who directed the movie 'Inception'?", + "What is the freezing point of water in Fahrenheit?", + "Which animal is known to have the longest lifespan?", + "What language has the most native speakers worldwide?", + "What is the capital city of Canada?", + "Who is credited with inventing the World Wide Web?", + "Which metal is liquid at room temperature?", + "What is the term for an animal that eats both plants and meat?", + "Who painted 'The Starry Night'?", + "What gas do humans exhale that plants use for photosynthesis?", + "What year did World War II end?", + "Which continent has the most countries?", + "Who wrote the novel 'Frankenstein'?", + "What does DNA stand for?", + "What is the main ingredient in traditional Japanese miso soup?" 
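+    // (keep this list aligned with k_answers below: the junk filler pairs
+    //  k_questions[r] with k_answers[r] for the same random index r)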
+};
+
+static std::vector<std::string> k_answers = {
+    "The tallest mountain in the world is Mount Everest.",
+    "Marie Curie was the first person to win two Nobel Prizes.",
+    "Paper was invented in China.",
+    "The heart is the organ responsible for pumping blood.",
+    "Saturn is known for its prominent ring system.",
+    "Christopher Nolan directed the movie 'Inception'.",
+    "The freezing point of water in Fahrenheit is 32°F.",
+    "The bowhead whale is known to have the longest lifespan among mammals.",
+    "Mandarin Chinese has the most native speakers in the world.",
+    "The capital city of Canada is Ottawa.",
+    "Tim Berners-Lee is credited with inventing the World Wide Web.",
+    "Mercury is the metal that is liquid at room temperature.",
+    "An animal that eats both plants and meat is called an omnivore.",
+    "'The Starry Night' was painted by Vincent van Gogh.",
+    "Humans exhale carbon dioxide, which plants use in photosynthesis.",
+    "World War II ended in 1945.",
+    "Africa is the continent with the most countries.",
+    "The novel 'Frankenstein' was written by Mary Shelley.",
+    "DNA stands for Deoxyribonucleic Acid.",
+    "The main ingredient in traditional Japanese miso soup is fermented soybean paste."
+};
+
+static std::vector<std::string> k_prompts = {
+    "What is the meaning of life?",
+    "Tell me an interesting fact about llamas.",
+    "What is the best way to cook a steak?",
+    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
+    "Recommend some interesting books to read.",
+    "What is the best way to learn a new language?",
+    "How to get a job at Google?",
+    "If you could have any superpower, what would it be?",
+    "I want to learn how to play the piano. What would be the best way to do it?",
+};
+
+struct client {
+    ~client() {
+        if (smpl) {
+            common_sampler_free(smpl);
+        }
+    }
+
+    int32_t id = 0;
+
+    llama_seq_id seq_id = -1;
+
+    llama_token sampled;
+
+    int64_t t_start_prompt;
+    int64_t t_start_gen;
+
+    int32_t n_past = 0;
+    int32_t n_prompt = 0;
+    int32_t n_decoded = 0;
+    int32_t i_batch = -1;
+
+    std::string input;
+    std::string prompt;
+    std::string response;
+
+    struct common_sampler * smpl = nullptr;
+};
+
+static void print_date_time() {
+    std::time_t current_time = std::time(nullptr);
+    std::tm* local_time = std::localtime(&current_time);
+    char buffer[80];
+    strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time);
+
+    LOG_INF("\n");
+    LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer);
+    LOG_INF("\n");
+}
+
+// Define a split string function to ...
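+// e.g. split_string("a\nb\nc", '\n') -> {"a", "b", "c"}; used below to turn the
+// contents of an external prompt file into one prompt per line.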
+static std::vector<std::string> split_string(const std::string& input, char delimiter) { + std::vector<std::string> tokens; + std::istringstream stream(input); + std::string token; + while (std::getline(stream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +int main(int argc, char ** argv) { + srand(1234); + + common_params params; + + params.n_predict = 128; + params.n_junk = 1; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) { + return 1; + } + + common_init(); + + // number of simultaneous "clients" to simulate + const int32_t n_clients = params.n_parallel; + + // dedicate one sequence to the system prompt + params.n_parallel += 1; + + // requests to simulate + const int32_t n_seq = params.n_sequences; + + // insert new requests as soon as the previous one is done + const bool cont_batching = params.cont_batching; + + // is the system prompt shared in the cache + const bool is_sp_shared = params.is_pp_shared; + + // extra text to insert in each client's prompt in order to make it larger + const int32_t n_junk = std::max(1, params.n_junk); + + // signed seed, use negative values to indicate different seeds for the different clients + const int32_t & sseed = params.sampling.seed; + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + // load the target model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + auto * mem = llama_get_memory(ctx); + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // load the prompts from an external file if there are any + if (params.prompt.empty()) { + LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); + } else { + // Output each line of the input params.prompts vector and copy to k_prompts + int index = 0; + LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str()); + + std::vector<std::string> prompts = split_string(params.prompt, '\n'); + for (const auto& prompt : prompts) { + k_prompts.resize(index + 1); + k_prompts[index] = prompt; + index++; + LOG_INF("%3d prompt: %s\n", index, prompt.c_str()); + } + } + + LOG_INF("\n\n"); + + const int n_ctx = llama_n_ctx(ctx); + + if (sseed >= 0) { + LOG_INF("%s: initializing all samplers with the same RNG seed: %d (use a negative seed to have different seeds)\n", __func__, sseed); + } else { + LOG_INF("%s: initializing samplers with different RNG seeds, starting from %d\n", __func__, sseed); + } + + std::vector<client> clients(n_clients); + for (size_t i = 0; i < clients.size(); ++i) { + auto & client = clients[i]; + client.id = i; + client.smpl = common_sampler_init(model, params.sampling); + + if (sseed < 0) { + params.sampling.seed--; + } + } + + std::vector<llama_token> tokens_system; + + tokens_system = common_tokenize(ctx, k_system, true); + const int32_t n_tokens_system = tokens_system.size(); + + llama_seq_id g_seq_id = 0; + + // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple + // users. 
regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time + llama_batch batch = llama_batch_init(n_ctx, 0, 1); + + int32_t n_total_prompt = 0; + int32_t n_total_gen = 0; + int32_t n_cache_miss = 0; + + const auto t_main_start = ggml_time_us(); + + LOG_INF("%s: Simulating parallel requests from clients:\n", __func__); + LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); + LOG_INF("\n"); + + if (is_sp_shared) { + LOG_INF("%s: Evaluating the system prompt ...\n", __func__); + + for (int32_t i = 0; i < n_tokens_system; ++i) { + common_batch_add(batch, tokens_system[i], i, { 0 }, false); + } + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + // assign the system KV cache to all parallel sequences + for (int32_t i = 1; i <= n_clients; ++i) { + llama_memory_seq_cp(mem, 0, i, -1, -1); + } + + LOG_INF("\n"); + } + + LOG_INF("Processing requests ...\n\n"); + + while (true) { + common_batch_clear(batch); + + // decode any currently ongoing sequences + for (auto & client : clients) { + if (client.seq_id == -1) { + continue; + } + + client.i_batch = batch.n_tokens; + + common_batch_add(batch, client.sampled, client.n_past++, { client.id + 1 }, true); + + client.n_decoded += 1; + } + + if (batch.n_tokens == 0) { + // all sequences have ended - clear the entire KV cache + for (int i = 1; i <= n_clients; ++i) { + llama_memory_seq_rm(mem, i, -1, -1); + // but keep the system prompt + llama_memory_seq_cp(mem, 0, i, -1, -1); + } + + LOG_INF("%s: clearing the KV cache\n", __func__); + } + + // insert new sequences for decoding + if (cont_batching || batch.n_tokens == 0) { + for (auto & client : clients) { + if (client.seq_id == -1 && g_seq_id < n_seq) { + client.seq_id = g_seq_id; + + client.t_start_prompt = ggml_time_us(); + client.t_start_gen = 0; + + client.input = k_prompts[rand() % k_prompts.size()]; + client.response = ""; + + // construct the prompt: + // [system prompt] + [junk] + [user prompt] + client.n_past = 0; + client.prompt = ""; + if (is_sp_shared) { + client.n_past = n_tokens_system; + } else { + client.prompt += k_system; + } + + const int n_junk_cur = rand() % n_junk; + + for (int i = 0; i < n_junk_cur; ++i) { + const int r = rand() % k_questions.size(); + client.prompt += "User:\n" + k_questions[r] + "\nAssistant:\n " + k_answers[r] + "\n"; + } + client.prompt += "User:\n" + client.input + "\nAssistant:\n"; + + common_sampler_reset(client.smpl); + + // do not prepend BOS because we have a system prompt! 
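+                    // (when the system prompt is shared it already occupies the
+                    //  first n_tokens_system positions of this sequence; otherwise
+                    //  it was prepended to client.prompt above, so the prompt is a
+                    //  continuation rather than a fresh context)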
+ std::vector<llama_token> tokens_prompt; + tokens_prompt = common_tokenize(ctx, client.prompt, false); + + for (size_t i = 0; i < tokens_prompt.size(); ++i) { + common_batch_add(batch, tokens_prompt[i], client.n_past++, { client.id + 1 }, false); + } + + // extract the logits only for the last token + if (batch.n_tokens > 0) { + batch.logits[batch.n_tokens - 1] = true; + } + + client.n_prompt = tokens_prompt.size(); + client.n_decoded = 0; + client.i_batch = batch.n_tokens - 1; + + LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt); + + g_seq_id += 1; + + // insert new requests one-by-one + //if (cont_batching) { + // break; + //} + } + } + } + + if (batch.n_tokens == 0) { + break; + } + + // process in chunks of params.n_batch + int32_t n_batch = params.n_batch; + + int32_t i_next = 0; + + for (int32_t i = 0; i < batch.n_tokens; i = i_next) { + // experiment: process in powers of 2 + //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { + // n_batch /= 2; + // i -= n_batch; + // continue; + //} + + const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + + llama_batch batch_view = { + n_tokens, + batch.token + i, + nullptr, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, + }; + + const int ret = llama_decode(ctx, batch_view); + if (ret != 0) { + if (n_batch == 1 || ret < 0) { + // if you get here, it means the KV cache is full - try increasing it via the context size + LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); + return 1; + } + + LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2); + + n_cache_miss += 1; + + // retry with half the batch size to try to find a free slot in the KV cache + n_batch /= 2; + + continue; + } + + LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens); + + // move the head of the batch forward with the number of tokens we just processed + i_next = i + n_tokens; + + // on successful decode, restore the original batch size + n_batch = params.n_batch; + + for (auto & client : clients) { + if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) { + continue; + } + + //printf("client %d, seq %d, token %d, pos %d, batch %d\n", + // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); + + const llama_token id = common_sampler_sample(client.smpl, ctx, client.i_batch - i); + + common_sampler_accept(client.smpl, id, true); + + if (client.n_decoded == 1) { + // start measuring generation time after the first token to make sure all concurrent clients + // have their prompt already processed + client.t_start_gen = ggml_time_us(); + } + + const std::string token_str = common_token_to_piece(ctx, id); + + client.response += token_str; + client.sampled = id; + + //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n", + // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); + + if (client.n_decoded > 2 && + (llama_vocab_is_eog(vocab, id) || + (params.n_predict > 0 && client.n_decoded >= params.n_predict) || + client.response.find("User:") != std::string::npos)) { + // basic reverse prompt + const size_t pos = client.response.find("User:"); + if (pos != std::string::npos) { + client.response = client.response.substr(0, pos); + } + + // delete only the generated part of the sequence, i.e. 
keep the system prompt in the cache
+                    llama_memory_seq_rm(mem, client.id + 1, -1, -1);
+                    llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
+
+                    const auto t_main_end = ggml_time_us();
+
+                    LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput:    %s\n\033[35mResponse: %s\033[0m\n\n",
+                            client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded,
+                            (t_main_end - client.t_start_prompt) / 1e6,
+                            (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
+                            n_cache_miss,
+                            ::trim(client.input).c_str(),
+                            ::trim(client.response).c_str());
+
+                    n_total_prompt += client.n_prompt;
+                    n_total_gen    += client.n_decoded;
+
+                    client.seq_id = -1;
+                }
+
+                client.i_batch = -1;
+            }
+        }
+    }
+
+    const auto t_main_end = ggml_time_us();
+
+    print_date_time();
+
+    LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
+    if (params.prompt_file.empty()) {
+        params.prompt_file = "used built-in defaults";
+    }
+    LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
+    LOG_INF("Model and path used:  \033[32m%s\033[0m\n\n", params.model.path.c_str());
+
+    LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt              ) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Total gen tokens:    %6d, speed: %5.2f t/s\n", n_total_gen,    (double) (n_total_gen                 ) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Total speed (AVG):   %6s  speed: %5.2f t/s\n", "",             (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
+    LOG_INF("Cache misses:        %6d\n", n_cache_miss);
+
+    LOG_INF("\n");
+
+    // TODO: print sampling/grammar timings for all clients
+    llama_perf_context_print(ctx);
+
+    llama_batch_free(batch);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
diff --git a/llama.cpp/examples/passkey/CMakeLists.txt b/llama.cpp/examples/passkey/CMakeLists.txt
new file mode 100644
index 0000000..9bc5110
--- /dev/null
+++ b/llama.cpp/examples/passkey/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-passkey)
+add_executable(${TARGET} passkey.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/passkey/README.md b/llama.cpp/examples/passkey/README.md
new file mode 100644
index 0000000..cbaf28f
--- /dev/null
+++ b/llama.cpp/examples/passkey/README.md
@@ -0,0 +1,15 @@
+# llama.cpp/examples/passkey
+
+A passkey retrieval task is an evaluation method used to measure a language
+model's ability to recall information from long contexts.
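+
+The example hides a randomly chosen number (the "passkey") at a configurable
+position inside a long stretch of repeated filler text and then asks the model
+to recall it. Illustratively, the constructed prompt looks roughly like this
+(`42` stands in for the passkey, which is drawn at random from 1-50000):
+
+```
+There is an important info hidden inside a lot of irrelevant text. [...]
+The grass is green. The sky is blue. [...] The pass key is 42. Remember it.
+42 is the pass key. [...] What is the pass key? The pass key is
+```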
+ +See the following PRs for more info: + +- https://github.com/ggml-org/llama.cpp/pull/3856 +- https://github.com/ggml-org/llama.cpp/pull/4810 + +### Usage + +```bash +llama-passkey -m ./models/llama-7b-v2/ggml-model-f16.gguf --junk 250 +``` diff --git a/llama.cpp/examples/passkey/passkey.cpp b/llama.cpp/examples/passkey/passkey.cpp new file mode 100644 index 0000000..8a4faa3 --- /dev/null +++ b/llama.cpp/examples/passkey/passkey.cpp @@ -0,0 +1,274 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include <cmath> +#include <cstdio> +#include <string> +#include <vector> +#include <algorithm> + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]); + LOG("\n"); +} + +int main(int argc, char ** argv) { + common_params params; + + params.n_junk = 250; + params.n_keep = 32; + params.i_pos = -1; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) { + return 1; + } + + common_init(); + + int n_junk = params.n_junk; + int n_keep = params.n_keep; + int n_grp = params.grp_attn_n; + int i_pos = params.i_pos; + + if (i_pos == -1) { + i_pos = rand() % n_junk; + } + + const std::string prompt_prefix = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."; + const std::string prompt_suffix = " What is the pass key? The pass key is"; + + // generate junk text + params.prompt = prompt_prefix; + + const int passkey = rand() % 50000 + 1; + + for (int i = 0; i < n_junk; i++) { + if (i % n_junk == i_pos) { + params.prompt += " The pass key is " + std::to_string(passkey) + ". Remember it. " + std::to_string(passkey) + " is the pass key."; + } + + params.prompt += " The grass is green. The sky is blue. The sun is yellow. Here we go. 
There and back again."; + } + + params.prompt += prompt_suffix; + + // init LLM + + llama_backend_init(); + llama_numa_init(params.numa); + + // initialize the model + + llama_model_params model_params = common_model_params_to_llama(params); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params); + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // initialize the context + + llama_context_params ctx_params = common_context_params_to_llama(params); + + ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep; + + GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (ctx == NULL) { + LOG_ERR("%s: failed to create the llama_context\n" , __func__); + return 1; + } + + auto sparams = llama_sampler_chain_default_params(); + + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + + // tokenize the prompt + std::vector<llama_token> tokens_list; + tokens_list = common_tokenize(ctx, params.prompt, true); + + // tokenize the prefix and use it as a sink + const int n_tokens_prefix = common_tokenize(ctx, prompt_prefix, true).size(); + + const int n_tokens_all = tokens_list.size(); + + // we leave a margin of 16 tokens for the generated text - it should contain just the passkey + const int n_predict = 16; + + // total length of the sequences including the prompt + const int n_len = n_tokens_all + n_predict; + + const int n_ctx = llama_n_ctx(ctx) - n_keep; + const int n_kv_req = llama_n_ctx(ctx); + const int n_batch = ctx_params.n_batch; + const int n_batch_grp = ctx_params.n_batch/n_grp; + + LOG_INF("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos); + + // print the prompt token-by-token + + LOG_INF("\n"); + LOG_INF("prefix tokens: %d\n", n_tokens_prefix); + LOG_INF("prompt tokens: %d\n", n_tokens_all); + //LOG_INF("prompt: %s\n", params.prompt.c_str()); + + llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + + int n_past = 0; + + auto * mem = llama_get_memory(ctx); + + // fill the KV cache + for (int i = 0; i < n_ctx; i += n_batch) { + if (i > 0 && n_grp > 1) { + // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp + const int ib = i/n_batch - 1; + const int bd = n_batch_grp*(n_grp - 1); + + llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd); + llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + + n_past = llama_memory_seq_pos_max(mem, 0) + 1; + } + + common_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { + common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + } + + if (i + n_batch >= n_tokens_all) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_INF("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + + if (i + n_batch >= n_tokens_all) { + break; + } + } + + for (int i = n_ctx; i < n_tokens_all; i += n_batch) { + const int n_discard = n_batch; + + LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard); + + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + 
llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); + + n_past = llama_memory_seq_pos_max(mem, 0) + 1; + + common_batch_clear(batch); + + for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { + common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + } + + if (i + n_batch >= n_tokens_all) { + batch.logits[batch.n_tokens - 1] = true; + } + + if (llama_decode(ctx, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all)); + } + + { + const int n_discard = n_past - n_ctx + n_predict; + + if (n_discard > 0) { + LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); + + llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard); + llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard); + + n_past = llama_memory_seq_pos_max(mem, 0) + 1; + } + } + + LOG_INF("\n"); + LOG_INF("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk); + LOG_INF("\n"); + + // main loop + + int n_cur = n_tokens_all; + int n_decode = 0; + + LOG_INF("%s", prompt_suffix.c_str()); + + const auto t_main_start = ggml_time_us(); + + while (n_cur <= n_len) { + // sample the next token + { + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); + + // is it an end of generation? + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { + LOG("\n"); + + break; + } + + LOG("%s", common_token_to_piece(ctx, new_token_id).c_str()); + + n_decode += 1; + + // prepare the next batch + common_batch_clear(batch); + + // push this new token for next evaluation + common_batch_add(batch, new_token_id, n_past++, { 0 }, true); + } + + n_cur += 1; + + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + } + + LOG("\n"); + + const auto t_main_end = ggml_time_us(); + + LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + LOG("\n"); + llama_perf_context_print(ctx); + + LOG("\n"); + + llama_sampler_free(smpl); + + llama_batch_free(batch); + + llama_free(ctx); + llama_model_free(model); + + llama_backend_free(); + + return 0; +} diff --git a/llama.cpp/examples/pydantic_models_to_grammar.py b/llama.cpp/examples/pydantic_models_to_grammar.py new file mode 100644 index 0000000..93e5dcb --- /dev/null +++ b/llama.cpp/examples/pydantic_models_to_grammar.py @@ -0,0 +1,1322 @@ +from __future__ import annotations + +import inspect +import json +import re +from copy import copy +from enum import Enum +from inspect import getdoc, isclass +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints + +from docstring_parser import parse +from pydantic import BaseModel, create_model + +if TYPE_CHECKING: + from types import GenericAlias +else: + # python 3.8 compat + from typing import _GenericAlias as GenericAlias + +# TODO: fix this +# pyright: reportAttributeAccessIssue=information + + +class PydanticDataType(Enum): + """ + Defines the data types supported by the grammar_generator. + + Attributes: + STRING (str): Represents a string data type. + BOOLEAN (str): Represents a boolean data type. + INTEGER (str): Represents an integer data type. 
+ FLOAT (str): Represents a float data type. + OBJECT (str): Represents an object data type. + ARRAY (str): Represents an array data type. + ENUM (str): Represents an enum data type. + CUSTOM_CLASS (str): Represents a custom class data type. + """ + + STRING = "string" + TRIPLE_QUOTED_STRING = "triple_quoted_string" + MARKDOWN_CODE_BLOCK = "markdown_code_block" + BOOLEAN = "boolean" + INTEGER = "integer" + FLOAT = "float" + OBJECT = "object" + ARRAY = "array" + ENUM = "enum" + ANY = "any" + NULL = "null" + CUSTOM_CLASS = "custom-class" + CUSTOM_DICT = "custom-dict" + SET = "set" + + +def map_pydantic_type_to_gbnf(pydantic_type: type[Any]) -> str: + origin_type = get_origin(pydantic_type) + origin_type = pydantic_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, str): + return PydanticDataType.STRING.value + elif isclass(origin_type) and issubclass(origin_type, bool): + return PydanticDataType.BOOLEAN.value + elif isclass(origin_type) and issubclass(origin_type, int): + return PydanticDataType.INTEGER.value + elif isclass(origin_type) and issubclass(origin_type, float): + return PydanticDataType.FLOAT.value + elif isclass(origin_type) and issubclass(origin_type, Enum): + return PydanticDataType.ENUM.value + + elif isclass(origin_type) and issubclass(origin_type, BaseModel): + return format_model_and_field_name(origin_type.__name__) + elif origin_type is list: + element_type = get_args(pydantic_type)[0] + return f"{map_pydantic_type_to_gbnf(element_type)}-list" + elif origin_type is set: + element_type = get_args(pydantic_type)[0] + return f"{map_pydantic_type_to_gbnf(element_type)}-set" + elif origin_type is Union: + union_types = get_args(pydantic_type) + union_rules = [map_pydantic_type_to_gbnf(ut) for ut in union_types] + return f"union-{'-or-'.join(union_rules)}" + elif origin_type is Optional: + element_type = get_args(pydantic_type)[0] + return f"optional-{map_pydantic_type_to_gbnf(element_type)}" + elif isclass(origin_type): + return f"{PydanticDataType.CUSTOM_CLASS.value}-{format_model_and_field_name(origin_type.__name__)}" + elif origin_type is dict: + key_type, value_type = get_args(pydantic_type) + return f"custom-dict-key-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(key_type))}-value-type-{format_model_and_field_name(map_pydantic_type_to_gbnf(value_type))}" + else: + return "unknown" + + +def format_model_and_field_name(model_name: str) -> str: + parts = re.findall("[A-Z][^A-Z]*", model_name) + if not parts: # Check if the list is empty + return model_name.lower().replace("_", "-") + return "-".join(part.lower().replace("_", "-") for part in parts) + + +def generate_list_rule(element_type): + """ + Generate a GBNF rule for a list of a given element type. + + :param element_type: The type of the elements in the list (e.g., 'string'). + :return: A string representing the GBNF rule for a list of the given type. 
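+
+    Example (derived from the rule template below):
+        generate_list_rule(str) -> 'string-list ::= "[" string ("," string)* "]"'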
+ """ + rule_name = f"{map_pydantic_type_to_gbnf(element_type)}-list" + element_rule = map_pydantic_type_to_gbnf(element_type) + list_rule = rf'{rule_name} ::= "[" {element_rule} ("," {element_rule})* "]"' + return list_rule + + +def get_members_structure(cls, rule_name): + if issubclass(cls, Enum): + # Handle Enum types + members = [f'"\\"{member.value}\\""' for name, member in cls.__members__.items()] + return f"{cls.__name__.lower()} ::= " + " | ".join(members) + if cls.__annotations__ and cls.__annotations__ != {}: + result = f'{rule_name} ::= "{{"' + # Modify this comprehension + members = [ + f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param_type)}' + for name, param_type in get_type_hints(cls).items() + if name != "self" + ] + + result += '"," '.join(members) + result += ' "}"' + return result + if rule_name == "custom-class-any": + result = f"{rule_name} ::= " + result += "value" + return result + + init_signature = inspect.signature(cls.__init__) + parameters = init_signature.parameters + result = f'{rule_name} ::= "{{"' + # Modify this comprehension too + members = [ + f' "\\"{name}\\"" ":" {map_pydantic_type_to_gbnf(param.annotation)}' + for name, param in parameters.items() + if name != "self" and param.annotation != inspect.Parameter.empty + ] + + result += '", "'.join(members) + result += ' "}"' + return result + + +def regex_to_gbnf(regex_pattern: str) -> str: + """ + Translate a basic regex pattern to a GBNF rule. + Note: This function handles only a subset of simple regex patterns. + """ + gbnf_rule = regex_pattern + + # Translate common regex components to GBNF + gbnf_rule = gbnf_rule.replace("\\d", "[0-9]") + gbnf_rule = gbnf_rule.replace("\\s", "[ \t\n]") + + # Handle quantifiers and other regex syntax that is similar in GBNF + # (e.g., '*', '+', '?', character classes) + + return gbnf_rule + + +def generate_gbnf_integer_rules(max_digit=None, min_digit=None): + """ + + Generate GBNF Integer Rules + + Generates GBNF (Generalized Backus-Naur Form) rules for integers based on the given maximum and minimum digits. + + Parameters: + max_digit (int): The maximum number of digits for the integer. Default is None. + min_digit (int): The minimum number of digits for the integer. Default is None. + + Returns: + integer_rule (str): The identifier for the integer rule generated. + additional_rules (list): A list of additional rules generated based on the given maximum and minimum digits. + + """ + additional_rules = [] + + # Define the rule identifier based on max_digit and min_digit + integer_rule = "integer-part" + if max_digit is not None: + integer_rule += f"-max{max_digit}" + if min_digit is not None: + integer_rule += f"-min{min_digit}" + + # Handling Integer Rules + if max_digit is not None or min_digit is not None: + # Start with an empty rule part + integer_rule_part = "" + + # Add mandatory digits as per min_digit + if min_digit is not None: + integer_rule_part += "[0-9] " * min_digit + + # Add optional digits up to max_digit + if max_digit is not None: + optional_digits = max_digit - (min_digit if min_digit is not None else 0) + integer_rule_part += "".join(["[0-9]? 
" for _ in range(optional_digits)]) + + # Trim the rule part and append it to additional rules + integer_rule_part = integer_rule_part.strip() + if integer_rule_part: + additional_rules.append(f"{integer_rule} ::= {integer_rule_part}") + + return integer_rule, additional_rules + + +def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None, min_precision=None): + """ + Generate GBNF float rules based on the given constraints. + + :param max_digit: Maximum number of digits in the integer part (default: None) + :param min_digit: Minimum number of digits in the integer part (default: None) + :param max_precision: Maximum number of digits in the fractional part (default: None) + :param min_precision: Minimum number of digits in the fractional part (default: None) + :return: A tuple containing the float rule and additional rules as a list + + Example Usage: + max_digit = 3 + min_digit = 1 + max_precision = 2 + min_precision = 1 + generate_gbnf_float_rules(max_digit, min_digit, max_precision, min_precision) + + Output: + ('float-3-1-2-1', ['integer-part-max3-min1 ::= [0-9] [0-9] [0-9]?', 'fractional-part-max2-min1 ::= [0-9] [0-9]?', 'float-3-1-2-1 ::= integer-part-max3-min1 "." fractional-part-max2-min + *1']) + + Note: + GBNF stands for Generalized Backus-Naur Form, which is a notation technique to specify the syntax of programming languages or other formal grammars. + """ + additional_rules = [] + + # Define the integer part rule + integer_part_rule = ( + "integer-part" + + (f"-max{max_digit}" if max_digit is not None else "") + + (f"-min{min_digit}" if min_digit is not None else "") + ) + + # Define the fractional part rule based on precision constraints + fractional_part_rule = "fractional-part" + fractional_rule_part = "" + if max_precision is not None or min_precision is not None: + fractional_part_rule += (f"-max{max_precision}" if max_precision is not None else "") + ( + f"-min{min_precision}" if min_precision is not None else "" + ) + # Minimum number of digits + fractional_rule_part = "[0-9]" * (min_precision if min_precision is not None else 1) + # Optional additional digits + fractional_rule_part += "".join( + [" [0-9]?"] * ((max_precision - ( + min_precision if min_precision is not None else 1)) if max_precision is not None else 0) + ) + additional_rules.append(f"{fractional_part_rule} ::= {fractional_rule_part}") + + # Define the float rule + float_rule = f"float-{max_digit if max_digit is not None else 'X'}-{min_digit if min_digit is not None else 'X'}-{max_precision if max_precision is not None else 'X'}-{min_precision if min_precision is not None else 'X'}" + additional_rules.append(f'{float_rule} ::= {integer_part_rule} "." {fractional_part_rule}') + + # Generating the integer part rule definition, if necessary + if max_digit is not None or min_digit is not None: + integer_rule_part = "[0-9]" + if min_digit is not None and min_digit > 1: + integer_rule_part += " [0-9]" * (min_digit - 1) + if max_digit is not None: + integer_rule_part += "".join([" [0-9]?"] * (max_digit - (min_digit if min_digit is not None else 1))) + additional_rules.append(f"{integer_part_rule} ::= {integer_rule_part.strip()}") + + return float_rule, additional_rules + + +def generate_gbnf_rule_for_type( + model_name, field_name, field_type, is_optional, processed_models, created_rules, field_info=None +) -> tuple[str, list[str]]: + """ + Generate GBNF rule for a given field type. + + :param model_name: Name of the model. + + :param field_name: Name of the field. 
+ :param field_type: Type of the field. + :param is_optional: Whether the field is optional. + :param processed_models: List of processed models. + :param created_rules: List of created rules. + :param field_info: Additional information about the field (optional). + + :return: Tuple containing the GBNF type and a list of additional rules. + :rtype: tuple[str, list] + """ + rules = [] + + field_name = format_model_and_field_name(field_name) + gbnf_type = map_pydantic_type_to_gbnf(field_type) + + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if isclass(origin_type) and issubclass(origin_type, BaseModel): + nested_model_name = format_model_and_field_name(field_type.__name__) + nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) + rules.extend(nested_model_rules) + gbnf_type, rules = nested_model_name, rules + elif isclass(origin_type) and issubclass(origin_type, Enum): + enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes + enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" + rules.append(enum_rule) + gbnf_type, rules = model_name + "-" + field_name, rules + elif origin_type is list: # Array + element_type = get_args(field_type)[0] + element_rule_name, additional_rules = generate_gbnf_rule_for_type( + model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules + ) + rules.extend(additional_rules) + array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """ + rules.append(array_rule) + gbnf_type, rules = model_name + "-" + field_name, rules + + elif origin_type is set: # Array + element_type = get_args(field_type)[0] + element_rule_name, additional_rules = generate_gbnf_rule_for_type( + model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules + ) + rules.extend(additional_rules) + array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """ + rules.append(array_rule) + gbnf_type, rules = model_name + "-" + field_name, rules + + elif gbnf_type.startswith("custom-class-"): + rules.append(get_members_structure(field_type, gbnf_type)) + elif gbnf_type.startswith("custom-dict-"): + key_type, value_type = get_args(field_type) + + additional_key_type, additional_key_rules = generate_gbnf_rule_for_type( + model_name, f"{field_name}-key-type", key_type, is_optional, processed_models, created_rules + ) + additional_value_type, additional_value_rules = generate_gbnf_rule_for_type( + model_name, f"{field_name}-value-type", value_type, is_optional, processed_models, created_rules + ) + gbnf_type = rf'{gbnf_type} ::= "{{" ( {additional_key_type} ": " {additional_value_type} ("," "\n" ws {additional_key_type} ":" {additional_value_type})* )? 
"}}" ' + + rules.extend(additional_key_rules) + rules.extend(additional_value_rules) + elif gbnf_type.startswith("union-"): + union_types = get_args(field_type) + union_rules = [] + + for union_type in union_types: + if isinstance(union_type, GenericAlias): + union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type( + model_name, field_name, union_type, False, processed_models, created_rules + ) + union_rules.append(union_gbnf_type) + rules.extend(union_rules_list) + + elif not issubclass(union_type, type(None)): + union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type( + model_name, field_name, union_type, False, processed_models, created_rules + ) + union_rules.append(union_gbnf_type) + rules.extend(union_rules_list) + + # Defining the union grammar rule separately + if len(union_rules) == 1: + union_grammar_rule = f"{model_name}-{field_name}-optional ::= {' | '.join(union_rules)} | null" + else: + union_grammar_rule = f"{model_name}-{field_name}-union ::= {' | '.join(union_rules)}" + rules.append(union_grammar_rule) + if len(union_rules) == 1: + gbnf_type = f"{model_name}-{field_name}-optional" + else: + gbnf_type = f"{model_name}-{field_name}-union" + elif isclass(origin_type) and issubclass(origin_type, str): + if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: + triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) + markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) + + gbnf_type = PydanticDataType.TRIPLE_QUOTED_STRING.value if triple_quoted_string else PydanticDataType.STRING.value + gbnf_type = PydanticDataType.MARKDOWN_CODE_BLOCK.value if markdown_string else gbnf_type + + elif field_info and hasattr(field_info, "pattern"): + # Convert regex pattern to grammar rule + regex_pattern = field_info.regex.pattern + gbnf_type = f"pattern-{field_name} ::= {regex_to_gbnf(regex_pattern)}" + else: + gbnf_type = PydanticDataType.STRING.value + + elif ( + isclass(origin_type) + and issubclass(origin_type, float) + and field_info + and hasattr(field_info, "json_schema_extra") + and field_info.json_schema_extra is not None + ): + # Retrieve precision attributes for floats + max_precision = ( + field_info.json_schema_extra.get("max_precision") if field_info and hasattr(field_info, + "json_schema_extra") else None + ) + min_precision = ( + field_info.json_schema_extra.get("min_precision") if field_info and hasattr(field_info, + "json_schema_extra") else None + ) + max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info, + "json_schema_extra") else None + min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info, + "json_schema_extra") else None + + # Generate GBNF rule for float with given attributes + gbnf_type, rules = generate_gbnf_float_rules( + max_digit=max_digits, min_digit=min_digits, max_precision=max_precision, min_precision=min_precision + ) + + elif ( + isclass(origin_type) + and issubclass(origin_type, int) + and field_info + and hasattr(field_info, "json_schema_extra") + and field_info.json_schema_extra is not None + ): + # Retrieve digit attributes for integers + max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info, + "json_schema_extra") else None + min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info, + "json_schema_extra") else None + + # Generate GBNF rule for integer with given 
attributes
+            gbnf_type, rules = generate_gbnf_integer_rules(max_digit=max_digits, min_digit=min_digits)
+    else:
+        gbnf_type, rules = gbnf_type, []
+
+    return gbnf_type, rules
+
+
+def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[BaseModel]], created_rules: dict[str, list[str]]) -> tuple[list[str], bool]:
+    """
+    Generate GBNF Grammar
+
+    Generates a GBNF grammar for a given model.
+
+    :param model: A Pydantic model class to generate the grammar for. Must be a subclass of BaseModel.
+    :param processed_models: A set of already processed models to prevent infinite recursion.
+    :param created_rules: A dict containing already created rules to prevent duplicates.
+    :return: A tuple of the list of GBNF grammar rules (as strings) and a flag that is True when the grammar
+        contains a special string type (a markdown code block or a triple-quoted string).
+    Example Usage:
+    ```
+    model = MyModel
+    processed_models = set()
+    created_rules = dict()
+
+    gbnf_grammar = generate_gbnf_grammar(model, processed_models, created_rules)
+    ```
+    """
+    if model in processed_models:
+        return [], False
+
+    processed_models.add(model)
+    model_name = format_model_and_field_name(model.__name__)
+
+    if not issubclass(model, BaseModel):
+        # For non-Pydantic classes, generate model_fields from __annotations__ or __init__
+        if hasattr(model, "__annotations__") and model.__annotations__:
+            model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()}
+        else:
+            init_signature = inspect.signature(model.__init__)
+            parameters = init_signature.parameters
+            model_fields = {name: (param.annotation, param.default) for name, param in parameters.items() if
+                            name != "self"}
+    else:
+        # For Pydantic models, use model_fields and check for ellipsis (required fields)
+        model_fields = get_type_hints(model)
+
+    model_rule_parts = []
+    nested_rules = []
+    has_markdown_code_block = False
+    has_triple_quoted_string = False
+    look_for_markdown_code_block = False
+    look_for_triple_quoted_string = False
+    for field_name, field_info in model_fields.items():
+        if not issubclass(model, BaseModel):
+            field_type, default_value = field_info
+            # Check if the field is optional (not required)
+            is_optional = (default_value is not inspect.Parameter.empty) and (default_value is not Ellipsis)
+        else:
+            field_type = field_info
+            field_info = model.model_fields[field_name]
+            is_optional = field_info.is_required is False and get_origin(field_type) is Optional
+        rule_name, additional_rules = generate_gbnf_rule_for_type(
+            model_name, format_model_and_field_name(field_name), field_type, is_optional, processed_models,
+            created_rules, field_info
+        )
+        look_for_markdown_code_block = True if rule_name == "markdown_code_block" else False
+        look_for_triple_quoted_string = True if rule_name == "triple_quoted_string" else False
+        if not look_for_markdown_code_block and not look_for_triple_quoted_string:
+            if rule_name not in created_rules:
+                created_rules[rule_name] = additional_rules
+            model_rule_parts.append(f' ws "\\"{field_name}\\"" ":" ws {rule_name}')  # Adding escaped quotes
+            nested_rules.extend(additional_rules)
+        else:
+            has_triple_quoted_string = look_for_triple_quoted_string
+            has_markdown_code_block = look_for_markdown_code_block
+
+    fields_joined = r' "," "\n" '.join(model_rule_parts)
+    model_rule = rf'{model_name} ::= "{{" "\n" {fields_joined} "\n" ws "}}"'
+
+    has_special_string = False
+    if has_triple_quoted_string:
+        model_rule += '"\\n" ws "}"'
+        model_rule += '"\\n" triple-quoted-string'
+        has_special_string = True
+    if
has_markdown_code_block: + model_rule += '"\\n" ws "}"' + model_rule += '"\\n" markdown-code-block' + has_special_string = True + all_rules = [model_rule] + nested_rules + + return all_rules, has_special_string + + +def generate_gbnf_grammar_from_pydantic_models( + models: list[type[BaseModel]], outer_object_name: str | None = None, outer_object_content: str | None = None, + list_of_outputs: bool = False +) -> str: + """ + Generate GBNF Grammar from Pydantic Models. + + This method takes a list of Pydantic models and uses them to generate a GBNF grammar string. The generated grammar string can be used for parsing and validating data using the generated + * grammar. + + Args: + models (list[type[BaseModel]]): A list of Pydantic models to generate the grammar from. + outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling. + outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling. + list_of_outputs (str, optional): Allows a list of output objects + Returns: + str: The generated GBNF grammar string. + + Examples: + models = [UserModel, PostModel] + grammar = generate_gbnf_grammar_from_pydantic(models) + print(grammar) + # Output: + # root ::= UserModel | PostModel + # ... + """ + processed_models: set[type[BaseModel]] = set() + all_rules = [] + created_rules: dict[str, list[str]] = {} + if outer_object_name is None: + for model in models: + model_rules, _ = generate_gbnf_grammar(model, processed_models, created_rules) + all_rules.extend(model_rules) + + if list_of_outputs: + root_rule = r'root ::= (" "| "\n") "[" ws grammar-models ("," ws grammar-models)* ws "]"' + "\n" + else: + root_rule = r'root ::= (" "| "\n") grammar-models' + "\n" + root_rule += "grammar-models ::= " + " | ".join( + [format_model_and_field_name(model.__name__) for model in models]) + all_rules.insert(0, root_rule) + return "\n".join(all_rules) + elif outer_object_name is not None: + if list_of_outputs: + root_rule = ( + rf'root ::= (" "| "\n") "[" ws {format_model_and_field_name(outer_object_name)} ("," ws {format_model_and_field_name(outer_object_name)})* ws "]"' + + "\n" + ) + else: + root_rule = f"root ::= {format_model_and_field_name(outer_object_name)}\n" + + model_rule = ( + rf'{format_model_and_field_name(outer_object_name)} ::= (" "| "\n") "{{" ws "\"{outer_object_name}\"" ":" ws grammar-models' + ) + + fields_joined = " | ".join( + [rf"{format_model_and_field_name(model.__name__)}-grammar-model" for model in models]) + + grammar_model_rules = f"\ngrammar-models ::= {fields_joined}" + mod_rules = [] + for model in models: + mod_rule = rf"{format_model_and_field_name(model.__name__)}-grammar-model ::= " + mod_rule += ( + rf'"\"{model.__name__}\"" "," ws "\"{outer_object_content}\"" ":" ws {format_model_and_field_name(model.__name__)}' + "\n" + ) + mod_rules.append(mod_rule) + grammar_model_rules += "\n" + "\n".join(mod_rules) + + for model in models: + model_rules, has_special_string = generate_gbnf_grammar(model, processed_models, + created_rules) + + if not has_special_string: + model_rules[0] += r'"\n" ws "}"' + + all_rules.extend(model_rules) + + all_rules.insert(0, root_rule + model_rule + grammar_model_rules) + return "\n".join(all_rules) + + +def get_primitive_grammar(grammar): + """ + Returns the needed GBNF primitive grammar for a given GBNF grammar string. + + Args: + grammar (str): The string containing the GBNF grammar. 
+ + Returns: + str: GBNF primitive grammar string. + """ + type_list: list[type[object]] = [] + if "string-list" in grammar: + type_list.append(str) + if "boolean-list" in grammar: + type_list.append(bool) + if "integer-list" in grammar: + type_list.append(int) + if "float-list" in grammar: + type_list.append(float) + additional_grammar = [generate_list_rule(t) for t in type_list] + primitive_grammar = r""" +boolean ::= "true" | "false" +null ::= "null" +string ::= "\"" ( + [^"\\] | + "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) + )* "\"" ws +ws ::= ([ \t\n] ws)? +float ::= ("-"? ([0] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws + +integer ::= [0-9]+""" + + any_block = "" + if "custom-class-any" in grammar: + any_block = """ +value ::= object | array | string | number | boolean | null + +object ::= + "{" ws ( + string ":" ws value + ("," ws string ":" ws value)* + )? "}" ws + +array ::= + "[" ws ( + value + ("," ws value)* + )? "]" ws + +number ::= integer | float""" + + markdown_code_block_grammar = "" + if "markdown-code-block" in grammar: + markdown_code_block_grammar = r''' +markdown-code-block ::= opening-triple-ticks markdown-code-block-content closing-triple-ticks +markdown-code-block-content ::= ( [^`] | "`" [^`] | "`" "`" [^`] )* +opening-triple-ticks ::= "```" "python" "\n" | "```" "c" "\n" | "```" "cpp" "\n" | "```" "txt" "\n" | "```" "text" "\n" | "```" "json" "\n" | "```" "javascript" "\n" | "```" "css" "\n" | "```" "html" "\n" | "```" "markdown" "\n" +closing-triple-ticks ::= "```" "\n"''' + + if "triple-quoted-string" in grammar: + markdown_code_block_grammar = r""" +triple-quoted-string ::= triple-quotes triple-quoted-string-content triple-quotes +triple-quoted-string-content ::= ( [^'] | "'" [^'] | "'" "'" [^'] )* +triple-quotes ::= "'''" """ + return "\n" + "\n".join(additional_grammar) + any_block + primitive_grammar + markdown_code_block_grammar + + +def generate_markdown_documentation( + pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields", + documentation_with_field_description=True +) -> str: + """ + Generate markdown documentation for a list of Pydantic models. + + Args: + pydantic_models (list[type[BaseModel]]): list of Pydantic model classes. + model_prefix (str): Prefix for the model section. + fields_prefix (str): Prefix for the fields section. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + str: Generated text documentation. 
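+
+    Example Usage (illustrative; `Book` stands for any Pydantic model):
+        docs = generate_markdown_documentation([Book])
+        # docs starts with "Model: Book" followed by an indented "Fields:" section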
+ """ + documentation = "" + pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] + for model, add_prefix in pyd_models: + if add_prefix: + documentation += f"{model_prefix}: {model.__name__}\n" + else: + documentation += f"Model: {model.__name__}\n" + + # Handling multi-line model description with proper indentation + + class_doc = getdoc(model) + base_class_doc = getdoc(BaseModel) + class_description = class_doc if class_doc and class_doc != base_class_doc else "" + if class_description != "": + documentation += " Description: " + documentation += format_multiline_description(class_description, 0) + "\n" + + if add_prefix: + # Indenting the fields section + documentation += f" {fields_prefix}:\n" + else: + documentation += f" Fields:\n" # noqa: F541 + if isclass(model) and issubclass(model, BaseModel): + for name, field_type in get_type_hints(model).items(): + # if name == "markdown_code_block": + # continue + if get_origin(field_type) == list: + element_type = get_args(field_type)[0] + if isclass(element_type) and issubclass(element_type, BaseModel): + pyd_models.append((element_type, False)) + if get_origin(field_type) == Union: + element_types = get_args(field_type) + for element_type in element_types: + if isclass(element_type) and issubclass(element_type, BaseModel): + pyd_models.append((element_type, False)) + documentation += generate_field_markdown( + name, field_type, model, documentation_with_field_description=documentation_with_field_description + ) + documentation += "\n" + + if hasattr(model, "Config") and hasattr(model.Config, + "json_schema_extra") and "example" in model.Config.json_schema_extra: + documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n" + json_example = json.dumps(model.Config.json_schema_extra["example"]) + documentation += format_multiline_description(json_example, 2) + "\n" + + return documentation + + +def generate_field_markdown( + field_name: str, field_type: type[Any], model: type[BaseModel], depth=1, + documentation_with_field_description=True +) -> str: + """ + Generate markdown documentation for a Pydantic model field. + + Args: + field_name (str): Name of the field. + field_type (type[Any]): Type of the field. + model (type[BaseModel]): Pydantic model class. + depth (int): Indentation depth in the documentation. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + str: Generated text documentation for the field. 
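+
+    Example (illustrative; `Book` is a hypothetical model with an undescribed `title: str` field):
+        generate_field_markdown("title", str, Book)  ->  "    title (str)\n"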
+ """ + indent = " " * depth + + field_info = model.model_fields.get(field_name) + field_description = field_info.description if field_info and field_info.description else "" + + origin_type = get_origin(field_type) + origin_type = field_type if origin_type is None else origin_type + + if origin_type == list: + element_type = get_args(field_type)[0] + field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + elif origin_type == Union: + element_types = get_args(field_type) + types = [] + for element_type in element_types: + types.append(format_model_and_field_name(element_type.__name__)) + field_text = f"{indent}{field_name} ({' or '.join(types)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + else: + field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + + if not documentation_with_field_description: + return field_text + + if field_description != "": + field_text += f" Description: {field_description}\n" + + # Check for and include field-specific examples if available + if hasattr(model, "Config") and hasattr(model.Config, + "json_schema_extra") and "example" in model.Config.json_schema_extra: + field_example = model.Config.json_schema_extra["example"].get(field_name) + if field_example is not None: + example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example + field_text += f"{indent} Example: {example_text}\n" + + if isclass(origin_type) and issubclass(origin_type, BaseModel): + field_text += f"{indent} Details:\n" + for name, type_ in get_type_hints(field_type).items(): + field_text += generate_field_markdown(name, type_, field_type, depth + 2) + + return field_text + + +def format_json_example(example: dict[str, Any], depth: int) -> str: + """ + Format a JSON example into a readable string with indentation. + + Args: + example (dict): JSON example to be formatted. + depth (int): Indentation depth. + + Returns: + str: Formatted JSON example string. + """ + indent = " " * depth + formatted_example = "{\n" + for key, value in example.items(): + value_text = f"'{value}'" if isinstance(value, str) else value + formatted_example += f"{indent}{key}: {value_text},\n" + formatted_example = formatted_example.rstrip(",\n") + "\n" + indent + "}" + return formatted_example + + +def generate_text_documentation( + pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields", + documentation_with_field_description=True +) -> str: + """ + Generate text documentation for a list of Pydantic models. + + Args: + pydantic_models (list[type[BaseModel]]): List of Pydantic model classes. + model_prefix (str): Prefix for the model section. + fields_prefix (str): Prefix for the fields section. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + str: Generated text documentation. 
+ """ + documentation = "" + pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] + for model, add_prefix in pyd_models: + if add_prefix: + documentation += f"{model_prefix}: {model.__name__}\n" + else: + documentation += f"Model: {model.__name__}\n" + + # Handling multi-line model description with proper indentation + + class_doc = getdoc(model) + base_class_doc = getdoc(BaseModel) + class_description = class_doc if class_doc and class_doc != base_class_doc else "" + if class_description != "": + documentation += " Description: " + documentation += "\n" + format_multiline_description(class_description, 2) + "\n" + + if isclass(model) and issubclass(model, BaseModel): + documentation_fields = "" + for name, field_type in get_type_hints(model).items(): + # if name == "markdown_code_block": + # continue + if get_origin(field_type) == list: + element_type = get_args(field_type)[0] + if isclass(element_type) and issubclass(element_type, BaseModel): + pyd_models.append((element_type, False)) + if get_origin(field_type) == Union: + element_types = get_args(field_type) + for element_type in element_types: + if isclass(element_type) and issubclass(element_type, BaseModel): + pyd_models.append((element_type, False)) + documentation_fields += generate_field_text( + name, field_type, model, documentation_with_field_description=documentation_with_field_description + ) + if documentation_fields != "": + if add_prefix: + documentation += f" {fields_prefix}:\n{documentation_fields}" + else: + documentation += f" Fields:\n{documentation_fields}" + documentation += "\n" + + if hasattr(model, "Config") and hasattr(model.Config, + "json_schema_extra") and "example" in model.Config.json_schema_extra: + documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n" + json_example = json.dumps(model.Config.json_schema_extra["example"]) + documentation += format_multiline_description(json_example, 2) + "\n" + + return documentation + + +def generate_field_text( + field_name: str, field_type: type[Any], model: type[BaseModel], depth=1, + documentation_with_field_description=True +) -> str: + """ + Generate text documentation for a Pydantic model field. + + Args: + field_name (str): Name of the field. + field_type (type[Any]): Type of the field. + model (type[BaseModel]): Pydantic model class. + depth (int): Indentation depth in the documentation. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + str: Generated text documentation for the field. 
+ """ + indent = " " * depth + + field_info = model.model_fields.get(field_name) + field_description = field_info.description if field_info and field_info.description else "" + + if get_origin(field_type) == list: + element_type = get_args(field_type)[0] + field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)} of {format_model_and_field_name(element_type.__name__)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + elif get_origin(field_type) == Union: + element_types = get_args(field_type) + types = [] + for element_type in element_types: + types.append(format_model_and_field_name(element_type.__name__)) + field_text = f"{indent}{field_name} ({' or '.join(types)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + else: + field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + if field_description != "": + field_text += ":\n" + else: + field_text += "\n" + + if not documentation_with_field_description: + return field_text + + if field_description != "": + field_text += f"{indent} Description: " + field_description + "\n" + + # Check for and include field-specific examples if available + if hasattr(model, "Config") and hasattr(model.Config, + "json_schema_extra") and "example" in model.Config.json_schema_extra: + field_example = model.Config.json_schema_extra["example"].get(field_name) + if field_example is not None: + example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example + field_text += f"{indent} Example: {example_text}\n" + + if isclass(field_type) and issubclass(field_type, BaseModel): + field_text += f"{indent} Details:\n" + for name, type_ in get_type_hints(field_type).items(): + field_text += generate_field_text(name, type_, field_type, depth + 2) + + return field_text + + +def format_multiline_description(description: str, indent_level: int) -> str: + """ + Format a multiline description with proper indentation. + + Args: + description (str): Multiline description. + indent_level (int): Indentation level. + + Returns: + str: Formatted multiline description. + """ + indent = " " * indent_level + return indent + description.replace("\n", "\n" + indent) + + +def save_gbnf_grammar_and_documentation( + grammar, documentation, grammar_file_path="./grammar.gbnf", documentation_file_path="./grammar_documentation.md" +): + """ + Save GBNF grammar and documentation to specified files. + + Args: + grammar (str): GBNF grammar string. + documentation (str): Documentation string. + grammar_file_path (str): File path to save the GBNF grammar. + documentation_file_path (str): File path to save the documentation. + + Returns: + None + """ + try: + with open(grammar_file_path, "w") as file: + file.write(grammar + get_primitive_grammar(grammar)) + print(f"Grammar successfully saved to {grammar_file_path}") + except IOError as e: + print(f"An error occurred while saving the grammar file: {e}") + + try: + with open(documentation_file_path, "w") as file: + file.write(documentation) + print(f"Documentation successfully saved to {documentation_file_path}") + except IOError as e: + print(f"An error occurred while saving the documentation file: {e}") + + +def remove_empty_lines(string): + """ + Remove empty lines from a string. + + Args: + string (str): Input string. + + Returns: + str: String with empty lines removed. 
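+
+    Example:
+        remove_empty_lines("a\n\n \nb")  ->  "a\nb"
+        (whitespace-only lines count as empty)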
+ """ + lines = string.splitlines() + non_empty_lines = [line for line in lines if line.strip() != ""] + string_no_empty_lines = "\n".join(non_empty_lines) + return string_no_empty_lines + + +def generate_and_save_gbnf_grammar_and_documentation( + pydantic_model_list, + grammar_file_path="./generated_grammar.gbnf", + documentation_file_path="./generated_grammar_documentation.md", + outer_object_name: str | None = None, + outer_object_content: str | None = None, + model_prefix: str = "Output Model", + fields_prefix: str = "Output Fields", + list_of_outputs: bool = False, + documentation_with_field_description=True, +): + """ + Generate GBNF grammar and documentation, and save them to specified files. + + Args: + pydantic_model_list: List of Pydantic model classes. + grammar_file_path (str): File path to save the generated GBNF grammar. + documentation_file_path (str): File path to save the generated documentation. + outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling. + outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling. + model_prefix (str): Prefix for the model section in the documentation. + fields_prefix (str): Prefix for the fields section in the documentation. + list_of_outputs (bool): Whether the output is a list of items. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + None + """ + documentation = generate_markdown_documentation( + pydantic_model_list, model_prefix, fields_prefix, + documentation_with_field_description=documentation_with_field_description + ) + grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, + list_of_outputs) + grammar = remove_empty_lines(grammar) + save_gbnf_grammar_and_documentation(grammar, documentation, grammar_file_path, documentation_file_path) + + +def generate_gbnf_grammar_and_documentation( + pydantic_model_list, + outer_object_name: str | None = None, + outer_object_content: str | None = None, + model_prefix: str = "Output Model", + fields_prefix: str = "Output Fields", + list_of_outputs: bool = False, + documentation_with_field_description=True, +): + """ + Generate GBNF grammar and documentation for a list of Pydantic models. + + Args: + pydantic_model_list: List of Pydantic model classes. + outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling. + outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling. + model_prefix (str): Prefix for the model section in the documentation. + fields_prefix (str): Prefix for the fields section in the documentation. + list_of_outputs (bool): Whether the output is a list of items. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + tuple: GBNF grammar string, documentation string. 
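+
+    Example Usage (illustrative; `Book` stands for any Pydantic model):
+        grammar, documentation = generate_gbnf_grammar_and_documentation([Book])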
+ """ + documentation = generate_markdown_documentation( + copy(pydantic_model_list), model_prefix, fields_prefix, + documentation_with_field_description=documentation_with_field_description + ) + grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, + list_of_outputs) + grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar)) + return grammar, documentation + + +def generate_gbnf_grammar_and_documentation_from_dictionaries( + dictionaries: list[dict[str, Any]], + outer_object_name: str | None = None, + outer_object_content: str | None = None, + model_prefix: str = "Output Model", + fields_prefix: str = "Output Fields", + list_of_outputs: bool = False, + documentation_with_field_description=True, +): + """ + Generate GBNF grammar and documentation from a list of dictionaries. + + Args: + dictionaries (list[dict]): List of dictionaries representing Pydantic models. + outer_object_name (str): Outer object name for the GBNF grammar. If None, no outer object will be generated. Eg. "function" for function calling. + outer_object_content (str): Content for the outer rule in the GBNF grammar. Eg. "function_parameters" or "params" for function calling. + model_prefix (str): Prefix for the model section in the documentation. + fields_prefix (str): Prefix for the fields section in the documentation. + list_of_outputs (bool): Whether the output is a list of items. + documentation_with_field_description (bool): Include field descriptions in the documentation. + + Returns: + tuple: GBNF grammar string, documentation string. + """ + pydantic_model_list = create_dynamic_models_from_dictionaries(dictionaries) + documentation = generate_markdown_documentation( + copy(pydantic_model_list), model_prefix, fields_prefix, + documentation_with_field_description=documentation_with_field_description + ) + grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, + list_of_outputs) + grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar)) + return grammar, documentation + + +def create_dynamic_model_from_function(func: Callable[..., Any]): + """ + Creates a dynamic Pydantic model from a given function's type hints and adds the function as a 'run' method. + + Args: + func (Callable): A function with type hints from which to create the model. + + Returns: + A dynamic Pydantic model class with the provided function as a 'run' method. + """ + + # Get the signature of the function + sig = inspect.signature(func) + + # Parse the docstring + assert func.__doc__ is not None + docstring = parse(func.__doc__) + + dynamic_fields = {} + param_docs = [] + for param in sig.parameters.values(): + # Exclude 'self' parameter + if param.name == "self": + continue + + # Assert that the parameter has a type annotation + if param.annotation == inspect.Parameter.empty: + raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation") + + # Find the parameter's description in the docstring + param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) + + # Assert that the parameter has a description + if not param_doc or not param_doc.description: + raise ValueError( + f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring") + + # Add parameter details to the schema + param_docs.append((param.name, param_doc)) + if param.default == inspect.Parameter.empty: + default_value = ... 
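+            # Ellipsis (...) is Pydantic's marker for a required field with no default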
+ else: + default_value = param.default + dynamic_fields[param.name] = ( + param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) + # Creating the dynamic model + dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) + + for name, param_doc in param_docs: + dynamic_model.model_fields[name].description = param_doc.description + + dynamic_model.__doc__ = docstring.short_description + + def run_method_wrapper(self): + func_args = {name: getattr(self, name) for name, _ in dynamic_fields.items()} + return func(**func_args) + + # Adding the wrapped function as a 'run' method + setattr(dynamic_model, "run", run_method_wrapper) + return dynamic_model + + +def add_run_method_to_dynamic_model(model: type[BaseModel], func: Callable[..., Any]): + """ + Add a 'run' method to a dynamic Pydantic model, using the provided function. + + Args: + model (type[BaseModel]): Dynamic Pydantic model class. + func (Callable): Function to be added as a 'run' method to the model. + + Returns: + type[BaseModel]: Pydantic model class with the added 'run' method. + """ + + def run_method_wrapper(self): + func_args = {name: getattr(self, name) for name in model.model_fields} + return func(**func_args) + + # Adding the wrapped function as a 'run' method + setattr(model, "run", run_method_wrapper) + + return model + + +def create_dynamic_models_from_dictionaries(dictionaries: list[dict[str, Any]]): + """ + Create a list of dynamic Pydantic model classes from a list of dictionaries. + + Args: + dictionaries (list[dict]): List of dictionaries representing model structures. + + Returns: + list[type[BaseModel]]: List of generated dynamic Pydantic model classes. + """ + dynamic_models = [] + for func in dictionaries: + model_name = format_model_and_field_name(func.get("name", "")) + dyn_model = convert_dictionary_to_pydantic_model(func, model_name) + dynamic_models.append(dyn_model) + return dynamic_models + + +def map_grammar_names_to_pydantic_model_class(pydantic_model_list): + output = {} + for model in pydantic_model_list: + output[format_model_and_field_name(model.__name__)] = model + + return output + + +def json_schema_to_python_types(schema): + type_map = { + "any": Any, + "string": str, + "number": float, + "integer": int, + "boolean": bool, + "array": list, + } + return type_map[schema] + + +def list_to_enum(enum_name, values): + return Enum(enum_name, {value: value for value in values}) + + +def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: str = "CustomModel") -> type[Any]: + """ + Convert a dictionary to a Pydantic model class. + + Args: + dictionary (dict): Dictionary representing the model structure. + model_name (str): Name of the generated Pydantic model. + + Returns: + type[BaseModel]: Generated Pydantic model class. + """ + fields: dict[str, Any] = {} + + if "properties" in dictionary: + for field_name, field_data in dictionary.get("properties", {}).items(): + if field_data == "object": + submodel = convert_dictionary_to_pydantic_model(dictionary, f"{model_name}_{field_name}") + fields[field_name] = (submodel, ...) + else: + field_type = field_data.get("type", "str") + + if field_data.get("enum", []): + fields[field_name] = (list_to_enum(field_name, field_data.get("enum", [])), ...) 
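+                # non-enum fields (arrays, nested objects, primitives) are mapped below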
+ elif field_type == "array": + items = field_data.get("items", {}) + if items != {}: + array = {"properties": items} + array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items") + fields[field_name] = (List[array_type], ...) + else: + fields[field_name] = (list, ...) + elif field_type == "object": + submodel = convert_dictionary_to_pydantic_model(field_data, f"{model_name}_{field_name}") + fields[field_name] = (submodel, ...) + elif field_type == "required": + required = field_data.get("enum", []) + for key, field in fields.items(): + if key not in required: + optional_type = fields[key][0] + fields[key] = (Optional[optional_type], ...) + else: + field_type = json_schema_to_python_types(field_type) + fields[field_name] = (field_type, ...) + if "function" in dictionary: + for field_name, field_data in dictionary.get("function", {}).items(): + if field_name == "name": + model_name = field_data + elif field_name == "description": + fields["__doc__"] = field_data + elif field_name == "parameters": + return convert_dictionary_to_pydantic_model(field_data, f"{model_name}") + + if "parameters" in dictionary: + field_data = {"function": dictionary} + return convert_dictionary_to_pydantic_model(field_data, f"{model_name}") + if "required" in dictionary: + required = dictionary.get("required", []) + for key, field in fields.items(): + if key not in required: + optional_type = fields[key][0] + fields[key] = (Optional[optional_type], ...) + custom_model = create_model(model_name, **fields) + return custom_model diff --git a/llama.cpp/examples/pydantic_models_to_grammar_examples.py b/llama.cpp/examples/pydantic_models_to_grammar_examples.py new file mode 100755 index 0000000..6dadb7f --- /dev/null +++ b/llama.cpp/examples/pydantic_models_to_grammar_examples.py @@ -0,0 +1,312 @@ +#!/usr/bin/env python3 + +"""Function calling example using pydantic models.""" + +from __future__ import annotations + +import argparse +import datetime +import json +import logging +import textwrap +import sys +from enum import Enum +from typing import Optional, Union + +import requests +from pydantic import BaseModel, Field +from pydantic_models_to_grammar import (add_run_method_to_dynamic_model, convert_dictionary_to_pydantic_model, + create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation) + + +def create_completion(host, prompt, gbnf_grammar): + """Calls the /completion API on llama-server. + + See + https://github.com/ggml-org/llama.cpp/tree/HEAD/tools/server#api-endpoints + """ + print(f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}") + headers = {"Content-Type": "application/json"} + data = {"prompt": prompt, "grammar": gbnf_grammar} + result = requests.post(f"http://{host}/completion", headers=headers, json=data).json() + assert data.get("error") is None, data + logging.info("Result: %s", result) + content = result["content"] + print(f" Model: {result['model']}") + print(f" Result:\n{textwrap.indent(json.dumps(json.loads(content), indent=2), ' ')}") + return content + + +# A function for the agent to send a message to the user. 
+class SendMessageToUser(BaseModel):
+    """Send a message to the User."""
+    chain_of_thought: str = Field(..., description="Your chain of thought while sending the message.")
+    message: str = Field(..., description="Message you want to send to the user.")
+
+    def run(self):
+        print(f"SendMessageToUser: {self.message}")
+
+
+def example_rce(host):
+    """Minimal test case where the LLM calls an arbitrary Python function."""
+    print("- example_rce")
+    tools = [SendMessageToUser]
+    gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(
+        pydantic_model_list=tools, outer_object_name="function",
+        outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters")
+    system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation
+    user_message = "What is 42 * 42?"
+    prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
+    text = create_completion(host, prompt, gbnf_grammar)
+    json_data = json.loads(text)
+    tools_map = {tool.__name__: tool for tool in tools}
+    # This finds "SendMessageToUser":
+    tool = tools_map.get(json_data["function"])
+    if not tool:
+        print(f"Error: unknown tool {json_data['function']}")
+        return 1
+    tool(**json_data["function_parameters"]).run()
+    return 0
+
+
+# Enum for the calculator tool.
+class MathOperation(Enum):
+    ADD = "add"
+    SUBTRACT = "subtract"
+    MULTIPLY = "multiply"
+    DIVIDE = "divide"
+
+
+# Simple pydantic calculator tool for the agent that can add, subtract,
+# multiply, and divide. The docstring and the description of the fields will be
+# used in the system prompt.
+class Calculator(BaseModel):
+    """Perform a math operation on two numbers."""
+    number_one: Union[int, float] = Field(..., description="First number.")
+    operation: MathOperation = Field(..., description="Math operation to perform.")
+    number_two: Union[int, float] = Field(..., description="Second number.")
+
+    def run(self):
+        if self.operation == MathOperation.ADD:
+            return self.number_one + self.number_two
+        elif self.operation == MathOperation.SUBTRACT:
+            return self.number_one - self.number_two
+        elif self.operation == MathOperation.MULTIPLY:
+            return self.number_one * self.number_two
+        elif self.operation == MathOperation.DIVIDE:
+            return self.number_one / self.number_two
+        else:
+            raise ValueError("Unknown operation.")
+
+
+def example_calculator(host):
+    """Have the LLM ask to get a calculation done.
+
+    Here the grammar gets generated by passing the available function models to
+    the generate_gbnf_grammar_and_documentation function. This also generates
+    documentation usable by the LLM.
+
+    pydantic_model_list is the list of pydantic models. outer_object_name is an
+    optional name for an outer object around the actual model object, like a
+    "function" object with "function_parameters" which contains the actual model
+    object; if None, no outer object will be generated. outer_object_content is
+    the name of the outer object content.
+
+    model_prefix is the optional prefix for models in the documentation. (Default="Output Model")
+    fields_prefix is the prefix for the model fields in the documentation. (Default="Output Fields")
(Default="Output Fields") + """ + print("- example_calculator") + tools = [SendMessageToUser, Calculator] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( + pydantic_model_list=tools, outer_object_name="function", + outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") + system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + user_message1 = "What is 42 * 42?" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message1}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + expected = { + "function": "Calculator", + "function_parameters": { + "number_one": 42, + "operation": "multiply", + "number_two": 42 + } + } + if json_data != expected: + print(" Result is not as expected!") + tools_map = {tool.__name__:tool for tool in tools} + # This finds "Calculator": + tool = tools_map.get(json_data["function"]) + if not tool: + print(f"Error: unknown tool {json_data['function']}") + return 1 + result = tool(**json_data["function_parameters"]).run() + print(f" Call {json_data['function']} gave result {result}") + return 0 + + +class Category(Enum): + """The category of the book.""" + Fiction = "Fiction" + NonFiction = "Non-Fiction" + + +class Book(BaseModel): + """Represents an entry about a book.""" + title: str = Field(..., description="Title of the book.") + author: str = Field(..., description="Author of the book.") + published_year: Optional[int] = Field(..., description="Publishing year of the book.") + keywords: list[str] = Field(..., description="A list of keywords.") + category: Category = Field(..., description="Category of the book.") + summary: str = Field(..., description="Summary of the book.") + + +def example_struct(host): + """A example structured output based on pydantic models. + + The LLM will create an entry for a Book database out of an unstructured + text. We need no additional parameters other than our list of pydantic + models. + """ + print("- example_struct") + tools = [Book] + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(pydantic_model_list=tools) + system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation + text = """The Feynman Lectures on Physics is a physics textbook based on some lectures by Richard Feynman, a Nobel laureate who has sometimes been called "The Great Explainer". The lectures were presented before undergraduate students at the California Institute of Technology (Caltech), during 1961–1963. The book's co-authors are Feynman, Robert B. Leighton, and Matthew Sands.""" + prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" + text = create_completion(host, prompt, gbnf_grammar) + json_data = json.loads(text) + # In this case, there's no function nor function_parameters. + # Here the result will vary based on the LLM used. 
+    keys = sorted(["title", "author", "published_year", "keywords", "category", "summary"])
+    if keys != sorted(json_data.keys()):
+        print(f"Unexpected result: {sorted(json_data.keys())}")
+        return 1
+    book = Book(**json_data)
+    print(f" As a Book object: {book}")
+    return 0
+
+
+def get_current_datetime(output_format: Optional[str] = None):
+    """Get the current date and time in the given format.
+
+    Args:
+        output_format: formatting string for the date and time, defaults to '%Y-%m-%d %H:%M:%S'
+    """
+    return datetime.datetime.now().strftime(output_format or "%Y-%m-%d %H:%M:%S")
+
+
+# Example function to get the weather.
+def get_current_weather(location, unit):
+    """Get the current weather in a given location."""
+    if "London" in location:
+        return json.dumps({"location": "London", "temperature": "42", "unit": unit.value})
+    elif "New York" in location:
+        return json.dumps({"location": "New York", "temperature": "24", "unit": unit.value})
+    elif "North Pole" in location:
+        return json.dumps({"location": "North Pole", "temperature": "-42", "unit": unit.value})
+    return json.dumps({"location": location, "temperature": "unknown"})
+
+
+def example_concurrent(host):
+    """An example of parallel function calling with a Python function, a pydantic
+    function model and an OpenAI-like function definition.
+    """
+    print("- example_concurrent")
+    # Function definition in OpenAI style.
+    current_weather_tool = {
+        "type": "function",
+        "function": {
+            "name": "get_current_weather",
+            "description": "Get the current weather in a given location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state, e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                },
+                "required": ["location"],
+            },
+        },
+    }
+    # Convert the OpenAI function definition into a pydantic model.
+    current_weather_tool_model = convert_dictionary_to_pydantic_model(current_weather_tool)
+    # Add the actual function to the pydantic model.
+    current_weather_tool_model = add_run_method_to_dynamic_model(current_weather_tool_model, get_current_weather)
+
+    # Convert a normal Python function to a pydantic model.
+    current_datetime_model = create_dynamic_model_from_function(get_current_datetime)
+
+    tools = [SendMessageToUser, Calculator, current_datetime_model, current_weather_tool_model]
+    gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(
+        pydantic_model_list=tools, outer_object_name="function",
+        outer_object_content="params", model_prefix="Function", fields_prefix="Parameters", list_of_outputs=True)
+    system_message = "You are an advanced AI assistant. You are interacting with the user and with your environment by calling functions. You call functions by writing JSON objects, which represent specific function calls.\nBelow is a list of your available function calls:\n\n" + documentation
+    text = """Get the date and time, get the current weather in celsius in London and solve the following calculation: 42 * 42"""
+    prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant"
+    text = create_completion(host, prompt, gbnf_grammar)
+    json_data = json.loads(text)
+    expected = [
+        {
+            "function": "get_current_datetime",
+            "params": {
+                "output_format": "%Y-%m-%d %H:%M:%S"
+            }
+        },
+        {
+            "function": "get_current_weather",
+            "params": {
+                "location": "London",
+                "unit": "celsius"
+            }
+        },
+        {
+            "function": "Calculator",
+            "params": {
+                "number_one": 42,
+                "operation": "multiply",
+                "number_two": 42
+            }
+        }
+    ]
+    res = 0
+    if json_data != expected:
+        print(" Result is not as expected!")
+        print(" This can happen on highly quantized models")
+        res = 1
+    tools_map = {tool.__name__: tool for tool in tools}
+    for call in json_data:
+        tool = tools_map.get(call["function"])
+        if not tool:
+            print(f"Error: unknown tool {call['function']}")
+            return 1
+        result = tool(**call["params"]).run()
+        print(f" Call {call['function']} returned {result}")
+    # Should output something like this:
+    # Call get_current_datetime returned 2024-07-15 09:50:38
+    # Call get_current_weather returned {"location": "London", "temperature": "42", "unit": "celsius"}
+    # Call Calculator returned 1764
+    return res
+
+
+def main():
+    parser = argparse.ArgumentParser(description=sys.modules[__name__].__doc__)
+    parser.add_argument("--host", default="localhost:8080", help="llama.cpp server")
+    parser.add_argument("-v", "--verbose", action="store_true", help="enables logging")
+    args = parser.parse_args()
+    logging.basicConfig(level=logging.INFO if args.verbose else logging.ERROR)
+    ret = 0
+    # Comment out the examples you don't want to run.
+    ret = ret or example_rce(args.host)
+    ret = ret or example_calculator(args.host)
+    ret = ret or example_struct(args.host)
+    ret = ret or example_concurrent(args.host)
+    return ret
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/llama.cpp/examples/reason-act.sh b/llama.cpp/examples/reason-act.sh
new file mode 100755
index 0000000..3c80192
--- /dev/null
+++ b/llama.cpp/examples/reason-act.sh
@@ -0,0 +1,16 @@
+#!/usr/bin/env bash
+
+cd `dirname $0`
+cd ..
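+
+# runs llama-cli interactively with the ReAct-style prompt from ./prompts/reason-act.txt;
+# pass "-m <model>" as the first two arguments to override the default model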
+
+# get the -m model parameter, otherwise defer to the default
+if [ "$1" == "-m" ]; then
+  MODEL="-m $2 "
+fi
+
+./llama-cli $MODEL --color \
+    -f ./prompts/reason-act.txt \
+    -i --interactive-first \
+    --top_k 10000 --temp 0.2 --repeat_penalty 1 -t 7 -c 2048 \
+    -r "Question:" -r "Observation:" --in-prefix " " \
+    -n -1
diff --git a/llama.cpp/examples/regex_to_grammar.py b/llama.cpp/examples/regex_to_grammar.py
new file mode 100644
index 0000000..5cd9210
--- /dev/null
+++ b/llama.cpp/examples/regex_to_grammar.py
@@ -0,0 +1,20 @@
+import json, subprocess, sys, os
+
+assert len(sys.argv) >= 2
+[_, pattern, *rest] = sys.argv
+
+print(subprocess.check_output(
+    [
+        "python",
+        os.path.join(
+            os.path.dirname(os.path.realpath(__file__)),
+            "json_schema_to_grammar.py"),
+        *rest,
+        "-",
+        "--raw-pattern",
+    ],
+    text=True,
+    input=json.dumps({
+        "type": "string",
+        "pattern": pattern,
+    }, indent=2)))
diff --git a/llama.cpp/examples/retrieval/CMakeLists.txt b/llama.cpp/examples/retrieval/CMakeLists.txt
new file mode 100644
index 0000000..512a602
--- /dev/null
+++ b/llama.cpp/examples/retrieval/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-retrieval)
+add_executable(${TARGET} retrieval.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/retrieval/README.md b/llama.cpp/examples/retrieval/README.md
new file mode 100644
index 0000000..51038cc
--- /dev/null
+++ b/llama.cpp/examples/retrieval/README.md
@@ -0,0 +1,69 @@
+# llama.cpp/examples/retrieval
+
+Demonstration of a simple retrieval technique based on cosine similarity.
+
+More info:
+https://github.com/ggml-org/llama.cpp/pull/6193
+
+### How to use
+
+`retrieval.cpp` has parameters of its own:
+- `--context-file`: file to be embedded; specify this option multiple times to embed multiple files
+- `--chunk-size`: minimum size of each text chunk to be embedded
+- `--chunk-separator`: string to divide chunks by (default: newline)
+
+The `retrieval` example can be tested as follows:
+
+```bash
+llama-retrieval --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .
+```
+
+This chunks and embeds all given files and starts a loop requesting query inputs:
+
+```
+Enter query:
+```
+
+On each query input, the top k chunks are shown along with the file name, the chunk position within the file and the original text:
+
+```
+Enter query: describe the mit license
+batch_decode: n_tokens = 6, n_seq = 1
+Top 3 similar chunks:
+filename: README.md
+filepos: 119
+similarity: 0.762334
+textdata:
+png)
+
+[](https://opensource.org/licenses/MIT)
+
+[Roadmap](https://github.
+--------------------
+filename: License
+filepos: 0
+similarity: 0.725146
+textdata:
+MIT License
+
+Copyright (c) 2023 Georgi Gerganov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+-------------------- +filename: README.md +filepos: 9178 +similarity: 0.621722 +textdata: +com/cztomsik/ava) (MIT) +- [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal) +- [pythops/tenere](https://github. +-------------------- +``` diff --git a/llama.cpp/examples/retrieval/retrieval.cpp b/llama.cpp/examples/retrieval/retrieval.cpp new file mode 100644 index 0000000..3f2afd4 --- /dev/null +++ b/llama.cpp/examples/retrieval/retrieval.cpp @@ -0,0 +1,304 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" + +#include <algorithm> +#include <fstream> +#include <iostream> // TODO: remove me + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s --model ./models/bge-base-en-v1.5-f16.gguf --top-k 3 --context-file README.md --context-file License --chunk-size 100 --chunk-separator .\n", argv[0]); + LOG("\n"); +} + +struct chunk { + // filename + std::string filename; + // original file position + size_t filepos; + // original text data + std::string textdata; + // tokenized text data + std::vector<llama_token> tokens; + // embedding + std::vector<float> embedding; +}; + +// chunk file data to chunks of size >= chunk_size +// chunk_separator is the separator between chunks +static std::vector<chunk> chunk_file(const std::string & filename, int chunk_size, const std::string & chunk_separator) { + std::vector<chunk> chunks; + std::ifstream f(filename.c_str()); + + if (!f.is_open()) { + LOG_ERR("could not open file %s\n", filename.c_str()); + return chunks; + } + + chunk current_chunk; + char buffer[1024]; + int64_t filepos = 0; + std::string current; + while (f.read(buffer, 1024)) { + current += std::string(buffer, f.gcount()); + size_t pos; + while ((pos = current.find(chunk_separator)) != std::string::npos) { + current_chunk.textdata += current.substr(0, pos + chunk_separator.size()); + if ((int) current_chunk.textdata.size() > chunk_size) { + // save chunk + current_chunk.filepos = filepos; + current_chunk.filename = filename; + chunks.push_back(current_chunk); + // update filepos + filepos += (int) current_chunk.textdata.size(); + // reset current_chunk + current_chunk = chunk(); + } + current = current.substr(pos + chunk_separator.size()); + } + + } + // add leftover data to last chunk + if (current_chunk.textdata.size() > 0) { + if (chunks.empty()) { + current_chunk.filepos = filepos; + current_chunk.filename = filename; + chunks.push_back(current_chunk); + } else { + chunks.back().textdata += current_chunk.textdata; + } + } + f.close(); + return chunks; +} + +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } +} + +static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { + // clear previous kv_cache values (irrelevant for embeddings) + llama_memory_clear(llama_get_memory(ctx), false); + + // run model + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + if (llama_decode(ctx, batch) < 0) { + LOG_ERR("%s : failed to process\n", __func__); + } + + for (int i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + + // try to get sequence embeddings - supported only when pooling_type is not NONE + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); 
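+            // fall back to token-level embeddings; if this also fails, skip the entry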
+ if (embd == NULL) { + LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); + continue; + } + } + + float * out = output + batch.seq_id[i][0] * n_embd; + common_embd_normalize(embd, out, n_embd, 2); + } +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) { + return 1; + } + + common_init(); + + // For BERT models, batch size must be equal to ubatch size + params.n_ubatch = params.n_batch; + params.embedding = true; + + if (params.chunk_size <= 0) { + LOG_ERR("chunk_size must be positive\n"); + return 1; + } + if (params.context_files.empty()) { + LOG_ERR("context_files must be specified\n"); + return 1; + } + + LOG_INF("processing files:\n"); + for (auto & context_file : params.context_files) { + LOG_INF("%s\n", context_file.c_str()); + } + + std::vector<chunk> chunks; + for (auto & context_file : params.context_files) { + std::vector<chunk> file_chunk = chunk_file(context_file, params.chunk_size, params.chunk_separator); + chunks.insert(chunks.end(), file_chunk.begin(), file_chunk.end()); + } + LOG_INF("Number of chunks: %zu\n", chunks.size()); + + llama_backend_init(); + llama_numa_init(params.numa); + + // load the model + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == NULL) { + LOG_ERR("%s: unable to load model\n", __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + const int n_ctx_train = llama_model_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + LOG_ERR("%s: pooling type NONE not supported\n", __func__); + return 1; + } + + if (n_ctx > n_ctx_train) { + LOG_WRN("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + } + + // max batch size + const uint64_t n_batch = params.n_batch; + GGML_ASSERT(params.n_batch >= params.n_ctx); + + // tokenize the prompts and trim + for (auto & chunk : chunks) { + auto inp = common_tokenize(ctx, chunk.textdata, true, false); + if (inp.size() > n_batch) { + LOG_ERR("%s: chunk size (%lld) exceeds batch size (%lld), increase batch size and re-run\n", + __func__, (long long int) inp.size(), (long long int) n_batch); + return 1; + } + // add eos if not present + if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) { + inp.push_back(llama_vocab_eos(vocab)); + } + chunk.tokens = inp; + } + + // tokenization stats + if (params.verbose_prompt) { + for (int i = 0; i < (int) chunks.size(); i++) { + LOG_INF("%s: prompt %d: '%s'\n", __func__, i, chunks[i].textdata.c_str()); + LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, chunks[i].tokens.size()); + for (int j = 0; j < (int) chunks[i].tokens.size(); j++) { + LOG_INF("%6d -> '%s'\n", chunks[i].tokens[j], common_token_to_piece(ctx, chunks[i].tokens[j]).c_str()); + } + LOG_INF("\n\n"); + } + } + + // initialize batch + const int n_chunks = chunks.size(); + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + + // allocate output + const int n_embd_out = llama_model_n_embd_out(model); + std::vector<float> embeddings(n_chunks * n_embd_out, 0); + float * emb = embeddings.data(); + + // break into 
batches + unsigned int p = 0; // number of prompts processed already + unsigned int s = 0; // number of prompts in current batch + for (int k = 0; k < n_chunks; k++) { + // clamp to n_batch tokens + auto & inp = chunks[k].tokens; + + const uint64_t n_toks = inp.size(); + + // encode if at capacity + if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) { + float * out = emb + p * n_embd_out; + batch_process(ctx, batch, out, s, n_embd_out); + common_batch_clear(batch); + p += s; + s = 0; + } + + // add to batch + batch_add_seq(batch, inp, s); + s += 1; + } + + // final batch + float * out = emb + p * n_embd_out; + batch_process(ctx, batch, out, s, n_embd_out); + + // save embeddings to chunks + for (int i = 0; i < n_chunks; i++) { + chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out); + // clear tokens as they are no longer needed + chunks[i].tokens.clear(); + } + + struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + + // start loop, receive query and return top k similar chunks based on cosine similarity + std::string query; + while (true) { + LOG("Enter query: "); + std::getline(std::cin, query); + std::vector<int32_t> query_tokens = common_tokenize(ctx, query, true); + + batch_add_seq(query_batch, query_tokens, 0); + + std::vector<float> query_emb(n_embd_out, 0); + batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out); + + common_batch_clear(query_batch); + + // compute cosine similarities + { + std::vector<std::pair<int, float>> similarities; + for (int i = 0; i < n_chunks; i++) { + float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out); + similarities.push_back(std::make_pair(i, sim)); + } + + // sort similarities + std::sort(similarities.begin(), similarities.end(), [](const std::pair<int, float> & a, const std::pair<int, float> & b) { + return a.second > b.second; + }); + + LOG("Top %d similar chunks:\n", params.sampling.top_k); + for (int i = 0; i < std::min(params.sampling.top_k, (int) chunks.size()); i++) { + LOG("filename: %s\n", chunks[similarities[i].first].filename.c_str()); + LOG("filepos: %lld\n", (long long int) chunks[similarities[i].first].filepos); + LOG("similarity: %f\n", similarities[i].second); + LOG("textdata:\n%s\n", chunks[similarities[i].first].textdata.c_str()); + LOG("--------------------\n"); + } + } + } + + LOG("\n"); + llama_perf_context_print(ctx); + + // clean up + llama_batch_free(query_batch); + llama_backend_free(); +} diff --git a/llama.cpp/examples/save-load-state/CMakeLists.txt b/llama.cpp/examples/save-load-state/CMakeLists.txt new file mode 100644 index 0000000..0f50e50 --- /dev/null +++ b/llama.cpp/examples/save-load-state/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-save-load-state) +add_executable(${TARGET} save-load-state.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/save-load-state/save-load-state.cpp b/llama.cpp/examples/save-load-state/save-load-state.cpp new file mode 100644 index 0000000..39d4464 --- /dev/null +++ b/llama.cpp/examples/save-load-state/save-load-state.cpp @@ -0,0 +1,258 @@ +#include "arg.h" +#include "common.h" +#include "llama.h" + +#include <vector> +#include <cstdio> + +int main(int argc, char ** argv) { + common_params params; + + params.prompt = "The quick brown fox"; + params.sampling.seed = 1234; + + if (!common_params_parse(argc, 
argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + if (params.n_parallel == 1) { + // the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache + printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__); + params.kv_unified = true; + } + + common_init(); + + if (params.n_predict < 0) { + params.n_predict = 16; + } + + auto n_past = 0; + + std::string result0; + std::string result1; + std::string result2; + + // init + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + fprintf(stderr, "%s : failed to init\n", __func__); + return 1; + } + + auto sparams = llama_sampler_chain_default_params(); + + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed)); + + // tokenize prompt + auto tokens = common_tokenize(ctx, params.prompt, true); + + // prepare the batch + llama_batch batch = llama_batch_init(tokens.size(), 0, 1); + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], i, {0}, false); + } + batch.logits[batch.n_tokens - 1] = true; // generate next token + + // evaluate prompt + llama_decode(ctx, batch); + n_past += batch.n_tokens; + + // save state (rng, logits, embedding and kv_cache) to file + { + std::vector<uint8_t> state_mem(llama_state_get_size(ctx)); + const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size()); + + FILE *fp_write = fopen("dump_state.bin", "wb"); + fwrite(state_mem.data(), 1, written, fp_write); + fclose(fp_write); + + fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size()); + } + + // save state (last tokens) + const auto n_past_saved = n_past; + + // first run + printf("\nfirst run: %s", params.prompt.c_str()); + + for (auto i = 0; i < params.n_predict; i++) { + auto next_token = llama_sampler_sample(smpl, ctx, -1); + auto next_token_str = common_token_to_piece(ctx, next_token); + + printf("%s", next_token_str.c_str()); + result0 += next_token_str; + + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx, batch)) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); + return 1; + } + n_past += 1; + } + + printf("\n\n"); + + // make new context + llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params)); + + llama_sampler * smpl2 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sampling.seed)); + + printf("\nsecond run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector<uint8_t> state_mem; + + FILE * fp_read = fopen("dump_state.bin", "rb"); + fseek(fp_read, 0, SEEK_END); + state_mem.resize(ftell(fp_read)); + fseek(fp_read, 0, SEEK_SET); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // second run + for (auto i = 0; i < params.n_predict; i++) { + auto next_token = 
llama_sampler_sample(smpl2, ctx2, -1); + auto next_token_str = common_token_to_piece(ctx2, next_token); + + printf("%s", next_token_str.c_str()); + result1 += next_token_str; + + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {0}, true); + + if (llama_decode(ctx2, batch)) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); + return 1; + } + n_past += 1; + } + + printf("\n\n"); + + if (result0 != result1) { + fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); + return 1; + } + + // make new context + llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); + + llama_sampler * smpl3 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sampling.seed)); + + printf("\nsingle seq run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector<uint8_t> state_mem; + + FILE * fp_read = fopen("dump_state.bin", "rb"); + fseek(fp_read, 0, SEEK_END); + state_mem.resize(ftell(fp_read)); + fseek(fp_read, 0, SEEK_SET); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // save seq 0 and load into seq 1 + { + // save kv of seq 0 + std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0)); + const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0); + if (ncopy != seq_store.size()) { + fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); + return 1; + } + fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); + + // erase whole kv + llama_memory_clear(llama_get_memory(ctx3), true); + fprintf(stderr, "%s : kv cache cleared\n", __func__); + + // restore kv into seq 1 + const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1); + if (nset != seq_store.size()) { + fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); + return 1; + } + fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); + } + + // third run with seq 1 instead of 0 + for (auto i = 0; i < params.n_predict; i++) { + auto next_token = llama_sampler_sample(smpl3, ctx3, -1); + auto next_token_str = common_token_to_piece(ctx3, next_token); + + printf("%s", next_token_str.c_str()); + result2 += next_token_str; + + common_batch_clear(batch); + common_batch_add(batch, next_token, n_past, {1}, true); + + if (llama_decode(ctx3, batch)) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_batch_free(batch); + return 1; + } + n_past += 1; + } + + printf("\n"); + + llama_sampler_free(smpl); + llama_sampler_free(smpl2); + llama_sampler_free(smpl3); + + llama_batch_free(batch); + + // this one is managed by common_init_result + //llama_free(ctx); + + llama_free(ctx2); + llama_free(ctx3); + + if (result0 != result2) { + fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); + return 1; + } + + fprintf(stderr, "\n%s : success\n", __func__); + + return 0; +} diff 
--git a/llama.cpp/examples/server-llama2-13B.sh b/llama.cpp/examples/server-llama2-13B.sh new file mode 100755 index 0000000..fd5a575 --- /dev/null +++ b/llama.cpp/examples/server-llama2-13B.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash + +set -e + +cd "$(dirname "$0")/.." || exit + +# Specify the model you want to use here: +MODEL="${MODEL:-./models/llama-2-13b-chat.ggmlv3.q5_K_M.bin}" +PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-./prompts/chat-system.txt} + +# Adjust to the number of CPU cores you want to use. +N_THREAD="${N_THREAD:-12}" + +# Note: you can also override the generation options by specifying them on the command line: +GEN_OPTIONS="${GEN_OPTIONS:---ctx_size 4096 --batch-size 1024}" + + +# shellcheck disable=SC2086 # Intended splitting of GEN_OPTIONS +./llama-server $GEN_OPTIONS \ + --model "$MODEL" \ + --threads "$N_THREAD" \ + --rope-freq-scale 1.0 \ + "$@" + +# I used this to test the model with mps, but omitted it from the general purpose. If you want to use it, just specify it on the command line. +# -ngl 1 \ diff --git a/llama.cpp/examples/server_embd.py b/llama.cpp/examples/server_embd.py new file mode 100644 index 0000000..f8b0ffe --- /dev/null +++ b/llama.cpp/examples/server_embd.py @@ -0,0 +1,35 @@ +import asyncio +import asyncio.threads +import requests +import numpy as np + + +n = 8 + +result = [] + +async def requests_post_async(*args, **kwargs): + return await asyncio.threads.to_thread(requests.post, *args, **kwargs) + +async def main(): + model_url = "http://127.0.0.1:6900" + responses: list[requests.Response] = await asyncio.gather(*[requests_post_async( + url= f"{model_url}/embedding", + json= {"content": "a "*1022} + ) for i in range(n)]) + + for response in responses: + embedding = response.json()["embedding"] + print(embedding[-8:]) + result.append(embedding) + +asyncio.run(main()) + +# compute cosine similarity + +for i in range(n-1): + for j in range(i+1, n): + embedding1 = np.array(result[i]) + embedding2 = np.array(result[j]) + similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) + print(f"Similarity between {i} and {j}: {similarity:.2f}") diff --git a/llama.cpp/examples/simple-chat/CMakeLists.txt b/llama.cpp/examples/simple-chat/CMakeLists.txt new file mode 100644 index 0000000..567f7fb --- /dev/null +++ b/llama.cpp/examples/simple-chat/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-simple-chat) +add_executable(${TARGET} simple-chat.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/simple-chat/README.md b/llama.cpp/examples/simple-chat/README.md new file mode 100644 index 0000000..f0099ce --- /dev/null +++ b/llama.cpp/examples/simple-chat/README.md @@ -0,0 +1,7 @@ +# llama.cpp/example/simple-chat + +The purpose of this example is to demonstrate a minimal usage of llama.cpp to create a simple chat program using the chat template from the GGUF file. + +```bash +./llama-simple-chat -m Meta-Llama-3.1-8B-Instruct.gguf -c 2048 +... 
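The key step in the program below is `llama_chat_apply_template`, which renders the accumulated messages into a prompt using the chat template stored in the GGUF metadata. A minimal sketch of just that step (simplified, with error handling and buffer growth omitted; see `simple-chat.cpp` below for the full flow):

```cpp
// sketch: render chat messages into a prompt with the model's built-in template
const char * tmpl = llama_model_chat_template(model, /* name */ nullptr);
std::vector<llama_chat_message> messages = {{"user", "Hello!"}};
std::vector<char> buf(4096);
// 'true' appends the assistant prefix so generation continues as the assistant
int len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true,
                                    buf.data(), buf.size());
std::string prompt(buf.data(), len); // tokenize and decode this as usual
```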
diff --git a/llama.cpp/examples/simple-chat/simple-chat.cpp b/llama.cpp/examples/simple-chat/simple-chat.cpp new file mode 100644 index 0000000..57195df --- /dev/null +++ b/llama.cpp/examples/simple-chat/simple-chat.cpp @@ -0,0 +1,207 @@ +#include "llama.h" +#include <cstdio> +#include <cstring> +#include <iostream> +#include <string> +#include <vector> + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-c context_size] [-ngl n_gpu_layers]\n", argv[0]); + printf("\n"); +} + +int main(int argc, char ** argv) { + std::string model_path; + int ngl = 99; + int n_ctx = 2048; + + // parse command line arguments + for (int i = 1; i < argc; i++) { + try { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { + model_path = argv[++i]; + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-c") == 0) { + if (i + 1 < argc) { + n_ctx = std::stoi(argv[++i]); + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + ngl = std::stoi(argv[++i]); + } else { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } catch (std::exception & e) { + fprintf(stderr, "error: %s\n", e.what()); + print_usage(argc, argv); + return 1; + } + } + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + + // only print errors + llama_log_set([](enum ggml_log_level level, const char * text, void * /* user_data */) { + if (level >= GGML_LOG_LEVEL_ERROR) { + fprintf(stderr, "%s", text); + } + }, nullptr); + + // load dynamic backends + ggml_backend_load_all(); + + // initialize the model + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + if (!model) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + + // initialize the context + llama_context_params ctx_params = llama_context_default_params(); + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = n_ctx; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + if (!ctx) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // initialize the sampler + llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params()); + llama_sampler_chain_add(smpl, llama_sampler_init_min_p(0.05f, 1)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); + + // helper function to evaluate a prompt and generate a response + auto generate = [&](const std::string & prompt) { + std::string response; + + const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == -1; + + // tokenize the prompt + const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true); + std::vector<llama_token> prompt_tokens(n_prompt_tokens); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), is_first, true) < 0) { + GGML_ABORT("failed to tokenize the prompt\n"); + } + + // prepare a batch for the prompt + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_token new_token_id; + while (true) { + // check if we have enough space in the context to evaluate this 
batch + int n_ctx = llama_n_ctx(ctx); + int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) + 1; + if (n_ctx_used + batch.n_tokens > n_ctx) { + printf("\033[0m\n"); + fprintf(stderr, "context size exceeded\n"); + exit(0); + } + + int ret = llama_decode(ctx, batch); + if (ret != 0) { + GGML_ABORT("failed to decode, ret = %d\n", ret); + } + + // sample the next token + new_token_id = llama_sampler_sample(smpl, ctx, -1); + + // is it an end of generation? + if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + // convert the token to a string, print it and add it to the response + char buf[256]; + int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + GGML_ABORT("failed to convert token to piece\n"); + } + std::string piece(buf, n); + printf("%s", piece.c_str()); + fflush(stdout); + response += piece; + + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + } + + return response; + }; + + std::vector<llama_chat_message> messages; + std::vector<char> formatted(llama_n_ctx(ctx)); + int prev_len = 0; + while (true) { + // get user input + printf("\033[32m> \033[0m"); + std::string user; + std::getline(std::cin, user); + + if (user.empty()) { + break; + } + + const char * tmpl = llama_model_chat_template(model, /* name */ nullptr); + + // add the user input to the message list and format it + messages.push_back({"user", strdup(user.c_str())}); + int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size()); + if (new_len > (int)formatted.size()) { + formatted.resize(new_len); + new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size()); + } + if (new_len < 0) { + fprintf(stderr, "failed to apply the chat template\n"); + return 1; + } + + // remove previous messages to obtain the prompt to generate the response + std::string prompt(formatted.begin() + prev_len, formatted.begin() + new_len); + + // generate a response + printf("\033[33m"); + std::string response = generate(prompt); + printf("\n\033[0m"); + + // add the response to the messages + messages.push_back({"assistant", strdup(response.c_str())}); + prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0); + if (prev_len < 0) { + fprintf(stderr, "failed to apply the chat template\n"); + return 1; + } + } + + // free resources + for (auto & msg : messages) { + free(const_cast<char *>(msg.content)); + } + llama_sampler_free(smpl); + llama_free(ctx); + llama_model_free(model); + + return 0; +} diff --git a/llama.cpp/examples/simple-cmake-pkg/.gitignore b/llama.cpp/examples/simple-cmake-pkg/.gitignore new file mode 100644 index 0000000..67c01d6 --- /dev/null +++ b/llama.cpp/examples/simple-cmake-pkg/.gitignore @@ -0,0 +1,50 @@ +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +*.gguf + +*.log +.DS_Store +.build/ +.cache/ +.direnv/ +.envrc +.swiftpm +.venv +.clang-tidy +.vs/ +.vscode/ + +build*/ +out/ +tmp/ diff --git a/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt b/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt new file mode 100644 index 0000000..128e38c --- /dev/null +++ 
b/llama.cpp/examples/simple-cmake-pkg/CMakeLists.txt
@@ -0,0 +1,11 @@
+cmake_minimum_required(VERSION 3.12)
+project(llama-simple-cmake-pkg)
+
+set(TARGET llama-simple-cmake-pkg)
+
+find_package(Llama REQUIRED)
+
+add_executable(${TARGET} ${CMAKE_CURRENT_LIST_DIR}/../simple/simple.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ggml::all ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/simple-cmake-pkg/README.md b/llama.cpp/examples/simple-cmake-pkg/README.md
new file mode 100644
index 0000000..0cabc2e
--- /dev/null
+++ b/llama.cpp/examples/simple-cmake-pkg/README.md
@@ -0,0 +1,35 @@
+# llama.cpp/example/simple-cmake-pkg
+
+This program builds [simple](../simple) using a relocatable CMake package. It serves as an example of using the `find_package()` CMake command to conveniently include [llama.cpp](https://github.com/ggml-org/llama.cpp) in projects which live outside of the source tree.
+
+## Building
+
+Because this example is "outside of the source tree", it is important to first build/install llama.cpp using CMake. An example is provided here, but please see the [llama.cpp build instructions](../..) for more detailed build instructions.
+
+### Considerations
+
+When hardware acceleration libraries are used (e.g. CUDA, Metal, Vulkan, etc.), the appropriate dependencies will be searched for automatically. So, for example, when finding a package which was built with CUDA support, the CUDA toolkit must also be discoverable by CMake at configure time.
+
+### Build llama.cpp and install to llama.cpp/inst
+
+```sh
+git clone https://github.com/ggml-org/llama.cpp
+cd llama.cpp
+cmake -S . -B build
+cmake --build build
+cmake --install build --prefix inst
+```
+
+### Build simple-cmake-pkg
+
+```sh
+cd examples/simple-cmake-pkg
+cmake -S . -B build -DCMAKE_PREFIX_PATH=../../inst/lib/cmake
+cmake --build build
+```
+
+### Run simple-cmake-pkg
+
+```sh
+./build/llama-simple-cmake-pkg -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
+```
diff --git a/llama.cpp/examples/simple/CMakeLists.txt b/llama.cpp/examples/simple/CMakeLists.txt
new file mode 100644
index 0000000..104ecab
--- /dev/null
+++ b/llama.cpp/examples/simple/CMakeLists.txt
@@ -0,0 +1,5 @@
+set(TARGET llama-simple)
+add_executable(${TARGET} simple.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/simple/README.md b/llama.cpp/examples/simple/README.md
new file mode 100644
index 0000000..937008b
--- /dev/null
+++ b/llama.cpp/examples/simple/README.md
@@ -0,0 +1,21 @@
+# llama.cpp/example/simple
+
+The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt.
+
+```bash
+./llama-simple -m ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
+
+...
+
+main: n_len = 32, n_ctx = 2048, n_parallel = 1, n_kv_req = 32
+
+ Hello my name is Shawn and I'm a 20 year old male from the United States.
I'm a 20 year old + +main: decoded 27 tokens in 2.31 s, speed: 11.68 t/s + +llama_print_timings: load time = 579.15 ms +llama_print_timings: sample time = 0.72 ms / 28 runs ( 0.03 ms per token, 38888.89 tokens per second) +llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms per token, 15.25 tokens per second) +llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second) +llama_print_timings: total time = 2891.13 ms +``` diff --git a/llama.cpp/examples/simple/simple.cpp b/llama.cpp/examples/simple/simple.cpp new file mode 100644 index 0000000..d09771d --- /dev/null +++ b/llama.cpp/examples/simple/simple.cpp @@ -0,0 +1,220 @@ +#include "llama.h" +#include <cstdio> +#include <cstring> +#include <string> +#include <vector> + +static void print_usage(int, char ** argv) { + printf("\nexample usage:\n"); + printf("\n %s -m model.gguf [-n n_predict] [-ngl n_gpu_layers] [prompt]\n", argv[0]); + printf("\n"); +} + +int main(int argc, char ** argv) { + // path to the model gguf file + std::string model_path; + // prompt to generate text from + std::string prompt = "Hello my name is"; + // number of layers to offload to the GPU + int ngl = 99; + // number of tokens to predict + int n_predict = 32; + + // parse command line arguments + + { + int i = 1; + for (; i < argc; i++) { + if (strcmp(argv[i], "-m") == 0) { + if (i + 1 < argc) { + model_path = argv[++i]; + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-n") == 0) { + if (i + 1 < argc) { + try { + n_predict = std::stoi(argv[++i]); + } catch (...) { + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else if (strcmp(argv[i], "-ngl") == 0) { + if (i + 1 < argc) { + try { + ngl = std::stoi(argv[++i]); + } catch (...) 
{ + print_usage(argc, argv); + return 1; + } + } else { + print_usage(argc, argv); + return 1; + } + } else { + // prompt starts here + break; + } + } + if (model_path.empty()) { + print_usage(argc, argv); + return 1; + } + if (i < argc) { + prompt = argv[i++]; + for (; i < argc; i++) { + prompt += " "; + prompt += argv[i]; + } + } + } + + // load dynamic backends + + ggml_backend_load_all(); + + // initialize the model + + llama_model_params model_params = llama_model_default_params(); + model_params.n_gpu_layers = ngl; + + llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); + + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); + return 1; + } + + const llama_vocab * vocab = llama_model_get_vocab(model); + // tokenize the prompt + + // find the number of tokens in the prompt + const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); + + // allocate space for the tokens and tokenize the prompt + std::vector<llama_token> prompt_tokens(n_prompt); + if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { + fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); + return 1; + } + + // initialize the context + + llama_context_params ctx_params = llama_context_default_params(); + // n_ctx is the context size + ctx_params.n_ctx = n_prompt + n_predict - 1; + // n_batch is the maximum number of tokens that can be processed in a single call to llama_decode + ctx_params.n_batch = n_prompt; + // enable performance counters + ctx_params.no_perf = false; + + llama_context * ctx = llama_init_from_model(model, ctx_params); + + if (ctx == NULL) { + fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); + return 1; + } + + // initialize the sampler + + auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + + // print the prompt token-by-token + + for (auto id : prompt_tokens) { + char buf[128]; + int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); + if (n < 0) { + fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); + return 1; + } + std::string s(buf, n); + printf("%s", s.c_str()); + } + + // prepare a batch for the prompt + + llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + + if (llama_model_has_encoder(model)) { + if (llama_encode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return 1; + } + + llama_token decoder_start_token_id = llama_model_decoder_start_token(model); + if (decoder_start_token_id == LLAMA_TOKEN_NULL) { + decoder_start_token_id = llama_vocab_bos(vocab); + } + + batch = llama_batch_get_one(&decoder_start_token_id, 1); + } + + // main loop + + const auto t_main_start = ggml_time_us(); + int n_decode = 0; + llama_token new_token_id; + + for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { + // evaluate the current batch with the transformer model + if (llama_decode(ctx, batch)) { + fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); + return 1; + } + + n_pos += batch.n_tokens; + + // sample the next token + { + new_token_id = llama_sampler_sample(smpl, ctx, -1); + + // is it an end of generation? 
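+            // (true for EOS and any other end-of-generation token, e.g. EOT)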
+ if (llama_vocab_is_eog(vocab, new_token_id)) { + break; + } + + char buf[128]; + int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true); + if (n < 0) { + fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); + return 1; + } + std::string s(buf, n); + printf("%s", s.c_str()); + fflush(stdout); + + // prepare the next batch with the sampled token + batch = llama_batch_get_one(&new_token_id, 1); + + n_decode += 1; + } + } + + printf("\n"); + + const auto t_main_end = ggml_time_us(); + + fprintf(stderr, "%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", + __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); + + fprintf(stderr, "\n"); + llama_perf_sampler_print(smpl); + llama_perf_context_print(ctx); + fprintf(stderr, "\n"); + + llama_sampler_free(smpl); + llama_free(ctx); + llama_model_free(model); + + return 0; +} diff --git a/llama.cpp/examples/speculative-simple/CMakeLists.txt b/llama.cpp/examples/speculative-simple/CMakeLists.txt new file mode 100644 index 0000000..aeaea74 --- /dev/null +++ b/llama.cpp/examples/speculative-simple/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-speculative-simple) +add_executable(${TARGET} speculative-simple.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/speculative-simple/README.md b/llama.cpp/examples/speculative-simple/README.md new file mode 100644 index 0000000..e3a6c6b --- /dev/null +++ b/llama.cpp/examples/speculative-simple/README.md @@ -0,0 +1,12 @@ +# llama.cpp/examples/speculative-simple + +Demonstration of basic greedy speculative decoding + +```bash +./bin/llama-speculative-simple \ + -m ../models/qwen2.5-32b-coder-instruct/ggml-model-q8_0.gguf \ + -md ../models/qwen2.5-1.5b-coder-instruct/ggml-model-q4_0.gguf \ + -f test.txt -c 0 -ngl 99 --color \ + --sampling-seq k --top-k 1 -fa --temp 0.0 \ + -ngld 99 --draft-max 16 --draft-min 5 --draft-p-min 0.9 +``` diff --git a/llama.cpp/examples/speculative-simple/speculative-simple.cpp b/llama.cpp/examples/speculative-simple/speculative-simple.cpp new file mode 100644 index 0000000..d8b1f5a --- /dev/null +++ b/llama.cpp/examples/speculative-simple/speculative-simple.cpp @@ -0,0 +1,266 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "speculative.h" +#include "log.h" +#include "llama.h" + +#include <cstdio> +#include <cstring> +#include <string> +#include <vector> + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { + return 1; + } + + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + + common_init(); + + if (params.speculative.mparams_dft.path.empty()) { + LOG_ERR("%s: --model-draft is required\n", __func__); + return 1; + } + + // init llama.cpp + llama_backend_init(); + llama_numa_init(params.numa); + + llama_model * model_tgt = NULL; + + llama_context * ctx_tgt = NULL; + + // load the target model + auto llama_init_tgt = common_init_from_params(params); + + model_tgt = llama_init_tgt->model(); + ctx_tgt = llama_init_tgt->context(); + + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); + + // load the draft model + llama_model_ptr model_dft; + + // TODO: simplify this logic + { + const auto & params_spec = params.speculative; + + auto params_dft = 
params; + + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); + params_dft.devices = params_spec.devices; + params_dft.model = params_spec.mparams_dft; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + } + + params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + + auto mparams_dft = common_model_params_to_llama(params_dft); + + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); + if (model_dft == nullptr) { + LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); + return 1; + } + + params.speculative.model_dft = model_dft.get(); + params.speculative.cparams_dft = common_context_params_to_llama(params_dft); + } + + // Tokenize the prompt + std::vector<llama_token> inp; + inp = common_tokenize(ctx_tgt, params.prompt, true, true); + + if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) { + LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt)); + + return 1; + } + + if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) { + LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt)); + + return 1; + } + + LOG("\n\n"); + + for (auto id : inp) { + LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); + } + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + // used to determine end of generation + bool has_eos = false; + + // ================================================ + // everything until here is standard initialization + // the relevant stuff for speculative decoding starts here + + const auto t_enc_start = ggml_time_us(); + + // target model sampling context + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); + + // eval the prompt + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + + // note: keep the last token separate! + llama_token id_last = inp.back(); + + // all tokens currently in the target context + llama_tokens prompt_tgt(inp.begin(), inp.end() - 1); + prompt_tgt.reserve(llama_n_ctx(ctx_tgt)); + + int n_past = inp.size() - 1; + + // init the speculator + const auto & params_spec = params.speculative; + + struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + + common_speculative_begin(spec, prompt_tgt); + + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + + const auto t_enc_end = ggml_time_us(); + + const auto t_dec_start = ggml_time_us(); + + while (true) { + // optionally, generate draft tokens that can be appended to the target batch + // + // this is the most important part of the speculation. the more probable tokens that are provided here + // the better the performance will be. in theory, this computation can be performed asynchronously and even + // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens + // from a cache or lookup tables. 
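+        //
+        // as a purely illustrative (unused, hypothetical) sketch of the "lookup" idea, a
+        // helper could return the tokens that followed the most recent occurrence of
+        // `id_last` in the already-generated text:
+        //
+        //     static llama_tokens draft_from_history(const llama_tokens & hist, llama_token id_last, size_t n_max) {
+        //         for (size_t i = hist.size(); i-- > 0; ) {
+        //             if (hist[i] == id_last && i + 1 < hist.size()) {
+        //                 const size_t n = std::min(n_max, hist.size() - (i + 1));
+        //                 return llama_tokens(hist.begin() + i + 1, hist.begin() + i + 1 + n);
+        //             }
+        //         }
+        //         return {};
+        //     }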
+ // + llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); + + //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); + + // always have a token to evaluate from before - id_last + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + + // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] + { + // do not waste time on small drafts + if (draft.size() < (size_t) params_spec.n_min) { + draft.clear(); + } + + for (size_t i = 0; i < draft.size(); ++i) { + common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + } + + //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); + + llama_decode(ctx_tgt, batch_tgt); + } + + // sample from the full target batch and return the accepted tokens based on the target sampler + // + // for each token to be accepted, the sampler would have to sample that same token + // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the + // available logits from the batch and sample the next token until we run out of logits or the sampler + // disagrees with the draft + // + const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + + //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); + + GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + + n_past += ids.size() - 1; + n_drafted += draft.size(); // note: we ignore the discarded small drafts + n_accept += ids.size() - 1; + n_predict += ids.size(); + + // process the accepted tokens and update contexts + // + // this is the standard token post-processing that we normally do + // in this case, we do it for a group of accepted tokens at once + // + for (size_t i = 0; i < ids.size(); ++i) { + prompt_tgt.push_back(id_last); + + id_last = ids[i]; + + if (llama_vocab_is_eog(vocab, id_last)) { + has_eos = true; + break; + } + + const std::string token_str = common_token_to_piece(ctx_tgt, id_last); + + if (params.use_color && i + 1 < ids.size()) { + LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str()); + } else { + LOG("%s", token_str.c_str()); + } + } + + LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last); + + { + LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); + + llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1); + } + + if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) { + break; + } + } + + auto t_dec_end = ggml_time_us(); + + const int n_input = inp.size(); + + LOG("\n\n"); + + LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); + LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); + + LOG_INF("\n"); + LOG_INF("n_draft = %d\n", params_spec.n_max); + LOG_INF("n_predict = %d\n", n_predict); + LOG_INF("n_drafted = %d\n", n_drafted); + LOG_INF("n_accept = %d\n", n_accept); + LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); + + LOG_INF("\n"); + LOG_INF("draft:\n\n"); + + LOG_INF("\n"); + LOG_INF("target:\n\n"); + common_perf_print(ctx_tgt, smpl); + + llama_batch_free(batch_tgt); + + common_sampler_free(smpl); + common_speculative_free(spec); + + llama_backend_free(); + + LOG("\n\n"); + + return 0; +} diff --git 
a/llama.cpp/examples/speculative/CMakeLists.txt b/llama.cpp/examples/speculative/CMakeLists.txt new file mode 100644 index 0000000..c84196b --- /dev/null +++ b/llama.cpp/examples/speculative/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-speculative) +add_executable(${TARGET} speculative.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/llama.cpp/examples/speculative/README.md b/llama.cpp/examples/speculative/README.md new file mode 100644 index 0000000..36ab370 --- /dev/null +++ b/llama.cpp/examples/speculative/README.md @@ -0,0 +1,9 @@ +# llama.cpp/examples/speculative + +Demonstration of speculative decoding and tree-based speculative decoding techniques + +More info: + +- https://github.com/ggml-org/llama.cpp/pull/2926 +- https://github.com/ggml-org/llama.cpp/pull/3624 +- https://github.com/ggml-org/llama.cpp/pull/5625 diff --git a/llama.cpp/examples/speculative/speculative.cpp b/llama.cpp/examples/speculative/speculative.cpp new file mode 100644 index 0000000..3e5cf5f --- /dev/null +++ b/llama.cpp/examples/speculative/speculative.cpp @@ -0,0 +1,648 @@ +#include "arg.h" +#include "common.h" +#include "sampling.h" +#include "log.h" +#include "llama.h" + +#include <algorithm> +#include <cstdio> +#include <cstring> +#include <random> +#include <set> +#include <string> +#include <vector> + +#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 +#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 + +struct seq_draft { + bool active = false; + bool drafting = false; + bool skip = false; + + int i_batch_dft = 0; + std::vector<int> i_batch_tgt; + + std::vector<llama_token> tokens; + std::vector<std::vector<llama_token_data>> dists; + + struct common_sampler * smpl = nullptr; +}; + +int main(int argc, char ** argv) { + common_params params; + + // needed to get candidate probs even for temp <= 0.0 + params.sampling.n_probs = 128; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { + return 1; + } + + if (params.n_predict < -1) { + LOG_ERR("%s: --n-predict must be >= -1\n", __func__); + return 1; + } + + common_init(); + + if (params.speculative.mparams_dft.path.empty()) { + LOG_ERR("%s: --model-draft is required\n", __func__); + return 1; + } + + // max number of parallel drafting sequences (i.e. tree branches) + const int n_seq_dft = params.n_parallel; + + // probability threshold for splitting a draft branch (only for n_seq_dft > 1) + const float p_draft_split = params.speculative.p_split; + + std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? 
std::random_device()() : params.sampling.seed);
+    std::uniform_real_distribution<> u_dist;
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    auto llama_init_tgt = common_init_from_params(params);
+
+    model_tgt = llama_init_tgt->model();
+    ctx_tgt = llama_init_tgt->context();
+
+    // load the draft model
+    params.devices = params.speculative.devices;
+    params.model = params.speculative.mparams_dft;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    params.tensor_buft_overrides = params.speculative.tensor_buft_overrides;
+
+    auto llama_init_dft = common_init_from_params(params);
+
+    model_dft = llama_init_dft->model();
+    ctx_dft = llama_init_dft->context();
+
+    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+
+    // note: llama_vocab_type() returns an enum - store it as int, not bool,
+    //       so that different vocab types do not collapse to the same value
+    const int vocab_type_tgt = llama_vocab_type(vocab_tgt);
+    LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
+
+    const int vocab_type_dft = llama_vocab_type(vocab_dft);
+    LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+        return 1;
+    }
+
+    if (
+        llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return 1;
+    }
+
+    {
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
+        const int vocab_diff = n_vocab_tgt > n_vocab_dft
+            ?
n_vocab_tgt - n_vocab_dft + : n_vocab_dft - n_vocab_tgt; + + if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { + LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); + LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", + n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + return 1; + } + + for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { + LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); + LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, + common_token_to_piece(ctx_tgt, i).c_str(), + common_token_to_piece(ctx_dft, i).c_str()); + return 1; + } + } + } + + auto * mem_tgt = llama_get_memory(ctx_tgt); + auto * mem_dft = llama_get_memory(ctx_dft); + + // Tokenize the prompt + std::vector<llama_token> inp; + inp = common_tokenize(ctx_tgt, params.prompt, true, true); + + const int max_context_size = llama_n_ctx(ctx_tgt); + const int max_tokens_list_size = max_context_size - 4; + + if ((int) inp.size() > max_tokens_list_size) { + LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size); + return 1; + } + + LOG("\n\n"); + + for (auto id : inp) { + LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); + } + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + + // eval the prompt with both models + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1)); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); + + const auto t_enc_end = ggml_time_us(); + + // the 2 models should have the same vocab + //GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft)); + + // how many tokens to draft each time + int n_draft = params.speculative.n_max; + + int n_predict = 0; + int n_drafted = 0; + int n_accept = 0; + + int n_past_tgt = inp.size(); + int n_past_dft = inp.size(); + + // used to determine end of generation + bool has_eos = false; + + // target model sampling context (reuse the llama_context's sampling instance) + struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); + + // draft sequence data + std::vector<seq_draft> drafts(n_seq_dft); + + for (int s = 0; s < n_seq_dft; ++s) { + // allocate llama_sampler for each draft sequence + drafts[s].smpl = common_sampler_init(model_dft, params.sampling); + } + + llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); + + const auto t_dec_start = ggml_time_us(); + + // sample from the last token of the prompt + drafts[0].i_batch_tgt.resize(1); + drafts[0].i_batch_tgt[0] = 0; + + while (true) { + std::set<int> active_seqs = {}; + + // print current draft sequences + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + + active_seqs.insert(s); + const auto & tokens = drafts[s].tokens; + + LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str()); + } + + int i_dft = 0; + int s_keep = 0; + + llama_token token_id; + std::string token_str; + + // loop until we fail to accept a drafted token or we run out of drafted 
tokens + while (true) { + + // check if the target token matches any of the drafts + // for stochastic sampling, attempt to match the token with the drafted tokens + { + bool accept = false; + if (params.sampling.temp > 0) { + // stochastic verification + common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); + + auto & dist_tgt = *common_sampler_get_candidates(smpl, true); + + float p_tgt = 0.0f; + float p_dft = 0.0f; + + while (active_seqs.size() > 0) { + // randomly select a sequence to verify from active sequences + std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1); + int s = *std::next(active_seqs.begin(), u_int_dist(rng)); + if (i_dft >= (int) drafts[s].tokens.size()) { + drafts[s].active = false; + active_seqs.erase(s); + continue; + } + if (accept) { + // if we already accepted a token, we can skip the rest + if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) { + drafts[s].active = false; + active_seqs.erase(s); + } + continue; + } + + LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); + float r = u_dist(rng); + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true }; + + //GGML_ASSERT(dist_tgt.size <= dist_dft.size); + + // acquire the token probabilities assigned by the draft and target models + for (size_t i = 0; i < dist_tgt.size; i++) { + if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { + p_tgt = dist_tgt.data[i].p; + break; + } + } + for (size_t i = 0; i < dist_dft.size; i++) { + if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) { + p_dft = dist_dft.data[i].p; + break; + } + } + LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt); + if (r <= p_tgt / p_dft) { + s_keep = s; + accept = true; + token_id = drafts[s].tokens[i_dft]; + token_str = common_token_to_piece(ctx_tgt, token_id); + common_sampler_accept(smpl, token_id, true); + + LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); + break; + } else { + LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str()); + drafts[s].active = false; + + // calculate residual probability + GGML_ASSERT(dist_tgt.sorted); + GGML_ASSERT(dist_dft.sorted); + + // sort dist by id + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.id < b.id; + }); + + float sum_probs = 0.0f; + + for (size_t i = 0; i < dist_tgt.size; i++) { + if (i < dist_dft.size) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + } else { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p); + } + + sum_probs += dist_tgt.data[i].p; + } + + for (size_t i = 0; i < dist_tgt.size; i++) { + dist_tgt.data[i].p /= sum_probs; + } + + // sort dist_tgt by p desc + std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { + return a.p > b.p; + }); + } + + active_seqs.erase(s); + for (int i = 0; i < n_seq_dft; i++) { + if (i == s) { + continue; + } + if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + // synchronize active status for sequences with the same drafted token + drafts[i].active = 
drafts[i].active && accept;
+                                if (!drafts[i].active) {
+                                    // note: erase the synchronized sequence (i); s was already removed above
+                                    active_seqs.erase(i);
+                                }
+                            }
+                        }
+                    }
+
+                    if (!accept) {
+                        // all drafted tokens were rejected
+                        // sample from the target model
+                        LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n");
+                        std::vector<float> probs(dist_tgt.size);
+                        for (size_t i = 0; i < dist_tgt.size; ++i) {
+                            probs[i] = dist_tgt.data[i].p;
+                        }
+
+                        std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+                        const int idx = dist(rng);
+
+                        token_id = dist_tgt.data[idx].id;
+                        common_sampler_accept(smpl, token_id, true);
+                        token_str = common_token_to_piece(ctx_tgt, token_id);
+                    }
+                } else {
+                    // greedy verification
+
+                    // sample from the target model
+                    LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+                    token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
+
+                    common_sampler_accept(smpl, token_id, true);
+
+                    token_str = common_token_to_piece(ctx_tgt, token_id);
+
+                    for (int s = 0; s < n_seq_dft; ++s) {
+                        if (!drafts[s].active) {
+                            continue;
+                        }
+
+                        if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
+                            LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
+
+                            s_keep = s;
+                            accept = true;
+                        } else {
+                            drafts[s].active = false;
+                        }
+                    }
+                }
+
+                if (llama_vocab_is_eog(vocab_tgt, token_id)) {
+                    has_eos = true;
+                }
+                ++n_predict;
+
+                if (accept) {
+                    ++n_accept;
+                    ++n_past_tgt;
+                    ++n_past_dft;
+                    ++i_dft;
+                    if (params.use_color) {
+                        // Color token according to its origin sequence
+                        LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
+                    } else {
+                        LOG("%s", token_str.c_str());
+                    }
+                    continue;
+                } else {
+                    LOG("%s", token_str.c_str());
+                    break;
+                }
+            }
+        }
+
+        {
+            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
+
+            // TODO: simplify
+            {
+                LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
+
+                llama_memory_seq_keep(mem_dft, s_keep);
+                llama_memory_seq_cp (mem_dft, s_keep, 0, -1, -1);
+                llama_memory_seq_keep(mem_dft, 0);
+
+                llama_memory_seq_rm (mem_tgt, s_keep, n_past_tgt, -1);
+                llama_memory_seq_keep(mem_tgt, s_keep);
+                llama_memory_seq_cp (mem_tgt, s_keep, 0, -1, -1);
+                llama_memory_seq_keep(mem_tgt, 0);
+            }
+
+            for (int s = 0; s < n_seq_dft; ++s) {
+                drafts[s].active = false;
+                drafts[s].tokens.clear();
+                drafts[s].i_batch_tgt.clear();
+                drafts[s].dists.clear();
+            }
+            // note: will be erased after the speculation phase
+            drafts[0].tokens.push_back(token_id);
+            drafts[0].dists.push_back(std::vector<llama_token_data>());
+            drafts[0].i_batch_tgt.push_back(0);
+
+            common_batch_clear(batch_dft);
+            common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
+
+            llama_memory_seq_rm(mem_dft, 0, n_past_dft, -1);
+            // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
+            llama_decode(ctx_dft, batch_dft);
+
+            ++n_past_dft;
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
+
+        if (drafts[0].smpl) {
+            common_sampler_free(drafts[0].smpl);
+        }
+        drafts[0].smpl = common_sampler_clone(smpl);
+
+        int n_seq_cur = 1;
+        int n_past_cur = n_past_dft;
+
+        for (int s = 0; s < n_seq_dft; ++s) {
+            drafts[s].active = false;
+            drafts[s].drafting = false;
+        }
+        drafts[0].active = true;
+        drafts[0].drafting = true;
+
drafts[0].i_batch_dft = 0; + + common_batch_clear(batch_tgt); + common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + + // sample n_draft tokens from the draft model using tree-based sampling + for (int i = 0; i < n_draft; ++i) { + batch_dft.n_tokens = 0; + + for (int s = 0; s < n_seq_dft; ++s) { + drafts[s].skip = false; + } + + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].drafting || drafts[s].skip) { + continue; + } + + common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); + + const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true); + + for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + std::vector<int> sa(1, s); + + // attempt to split the branch if the probability is high enough + for (int f = 1; f < 8; ++f) { + if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) { + LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur); + + llama_memory_seq_rm(mem_dft, n_seq_cur, -1, -1); + llama_memory_seq_cp(mem_dft, s, n_seq_cur, -1, -1); + + // all previous tokens from this branch are now also part of the new branch + for (int t = 0; t < batch_tgt.n_tokens; ++t) { + for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { + if (batch_tgt.seq_id[t][p] == s) { + batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; + batch_tgt.n_seq_id[t]++; + break; + } + } + } + + // copy the draft state + drafts[n_seq_cur].active = true; + drafts[n_seq_cur].drafting = true; + drafts[n_seq_cur].skip = true; + + drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].dists = drafts[s].dists; + drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; + drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; + + if (drafts[n_seq_cur].smpl) { + common_sampler_free(drafts[n_seq_cur].smpl); + } + drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl); + + sa.push_back(n_seq_cur); + + n_seq_cur++; + } else { + break; + } + } + + // add drafted token for each sequence + for (int is = 0; is < (int) sa.size(); ++is) { + const llama_token id = cur_p->data[is].id; + + const int s = sa[is]; + + common_sampler_accept(drafts[s].smpl, id, true); + + drafts[s].tokens.push_back(id); + // save cur_p.data into drafts[s].dists + drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); + + // add unique drafted tokens to the target batch + drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + + common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + + // add the token to the batch for batched decoding with the draft model + drafts[s].i_batch_dft = batch_dft.n_tokens; + + common_batch_add(batch_dft, id, n_past_cur, { s }, true); + + if (batch_tgt.n_tokens > n_draft) { + drafts[s].drafting = false; + } + } + } + + // no sequence is drafting anymore + if (batch_dft.n_tokens == 0) { + break; + } + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch_dft); + ++n_past_cur; + ++n_drafted; + + if (batch_tgt.n_tokens > n_draft) { + break; + } + } + + // evaluate the target model on the drafted tokens + { + llama_memory_seq_keep(mem_tgt, 0); + for (int s = 1; s < n_seq_dft; ++s) { + llama_memory_seq_cp(mem_tgt, 0, s, -1, -1); + } + + // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); + llama_decode(ctx_tgt, batch_tgt); + ++n_past_tgt; + } + + // the first 
token is always proposed by the target model before the speculation loop so we erase it here
+        for (int s = 0; s < n_seq_dft; ++s) {
+            if (!drafts[s].active) {
+                continue;
+            }
+
+            drafts[s].tokens.erase(drafts[s].tokens.begin());
+            drafts[s].dists.erase(drafts[s].dists.begin());
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept = %d\n", n_accept);
+    LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+    // TODO: print sampling/grammar timings for all drafts
+    llama_perf_context_print(ctx_dft);
+
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);
+
+    common_sampler_free(smpl);
+    for (int s = 0; s < n_seq_dft; ++s) {
+        common_sampler_free(drafts[s].smpl);
+    }
+
+    llama_batch_free(batch_dft);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
diff --git a/llama.cpp/examples/sycl/CMakeLists.txt b/llama.cpp/examples/sycl/CMakeLists.txt
new file mode 100644
index 0000000..e4d5083
--- /dev/null
+++ b/llama.cpp/examples/sycl/CMakeLists.txt
@@ -0,0 +1,9 @@
+# MIT license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT
+
+set(TARGET llama-ls-sycl-device)
+add_executable(${TARGET} ls-sycl-device.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llama.cpp/examples/sycl/README.md b/llama.cpp/examples/sycl/README.md
new file mode 100644
index 0000000..8819d87
--- /dev/null
+++ b/llama.cpp/examples/sycl/README.md
@@ -0,0 +1,41 @@
+# llama.cpp/examples/sycl
+
+This directory provides tools for running llama.cpp with SYCL on Intel GPUs.
+
+## Tool
+
+|Tool Name| Function|Status|
+|-|-|-|
+|llama-ls-sycl-device| List all SYCL devices with ID, compute capability, max work group size, etc.|Supported|
+
+### llama-ls-sycl-device
+
+List all SYCL devices with ID, compute capability, max work group size, etc.
+
+1. Build llama.cpp for SYCL for the specified target *(using GGML_SYCL_TARGET)*.
+
+2. Enable the oneAPI runtime environment *(if GGML_SYCL_TARGET is set to INTEL -default-)*
+
+```
+source /opt/intel/oneapi/setvars.sh
+```
+
+3.
Execute
+
+```
+./build/bin/llama-ls-sycl-device
+```
+
+Check the device IDs in the startup log, for example:
+
+```
+found 2 SYCL devices:
+| | | | |Max | |Max |Global | |
+| | | | |compute|Max work|sub |mem | |
+|ID| Device Type| Name|Version|units |group |group|size | Driver version|
+|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|
+| 0| [level_zero:gpu:0]| Intel Arc A770 Graphics| 1.3| 512| 1024| 32| 16225M| 1.3.29138|
+| 1| [level_zero:gpu:1]| Intel UHD Graphics 750| 1.3| 32| 512| 32| 62631M| 1.3.29138|
+
+```
+
diff --git a/llama.cpp/examples/sycl/build.sh b/llama.cpp/examples/sycl/build.sh
new file mode 100755
index 0000000..635e74f
--- /dev/null
+++ b/llama.cpp/examples/sycl/build.sh
@@ -0,0 +1,23 @@
+#!/usr/bin/env bash
+# MIT license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT
+
+mkdir -p build
+cd build
+source /opt/intel/oneapi/setvars.sh
+
+#for FP16
+#cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DLLAMA_OPENSSL=OFF # faster for long-prompt inference
+
+#for FP32
+cmake .. -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DLLAMA_OPENSSL=OFF
+
+#build example/main
+#cmake --build . --config Release --target main
+
+#build example/llama-bench
+#cmake --build . --config Release --target llama-bench
+
+#build all binaries
+cmake --build . --config Release -j -v
diff --git a/llama.cpp/examples/sycl/ls-sycl-device.cpp b/llama.cpp/examples/sycl/ls-sycl-device.cpp
new file mode 100644
index 0000000..74a8b7f
--- /dev/null
+++ b/llama.cpp/examples/sycl/ls-sycl-device.cpp
@@ -0,0 +1,13 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+
+#include "ggml-sycl.h"
+
+int main() {
+    ggml_backend_sycl_print_sycl_devices();
+    return 0;
+}
diff --git a/llama.cpp/examples/sycl/run-llama2.sh b/llama.cpp/examples/sycl/run-llama2.sh
new file mode 100755
index 0000000..d33f82f
--- /dev/null
+++ b/llama.cpp/examples/sycl/run-llama2.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/env bash
+
+# MIT license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT
+export ONEAPI_DEVICE_SELECTOR="level_zero:0"
+source /opt/intel/oneapi/setvars.sh
+
+#export GGML_SYCL_DEBUG=1
+
+# ZES_ENABLE_SYSMAN=1 enables querying the free GPU memory via sycl::aspect::ext_intel_free_memory. Recommended when --split-mode = layer.
+
+INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
+MODEL_FILE=models/llama-2-7b.Q4_0.gguf
+NGL=99
+CONTEXT=4096
+
+# allow allocating device memory larger than 4 GB
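+# example usage (GPU ID as optional first argument):
+#   ./examples/sycl/run-llama2.sh      # all Intel GPUs
+#   ./examples/sycl/run-llama2.sh 0    # GPU 0 only (-sm none)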
+export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
+
+LOAD_MODE='--mmap'
+if [ $# -gt 0 ]; then
+    GGML_SYCL_DEVICE=$1
+    echo "use $GGML_SYCL_DEVICE as main GPU"
+    #use single GPU only
+    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
+
+else
+    #use multiple GPUs with same max compute units
+    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
+fi
diff --git a/llama.cpp/examples/sycl/test.sh b/llama.cpp/examples/sycl/test.sh
new file mode 100755
index 0000000..140c191
--- /dev/null
+++ b/llama.cpp/examples/sycl/test.sh
@@ -0,0 +1,130 @@
+#!/bin/bash
+
+# MIT license
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: MIT
+
+Help() {
+    cat << EOF
+Usage: $(basename "$0") [OPTIONS]
+
+This script runs llama-completion on the SYCL backend with the specified options.
+
+Options:
+  -h, --help                  Display this help message and exit.
+  -c, --context <value>       Set context length. Bigger values need more memory.
+  -p, --prompt <value>        Prompt to start generation with.
+  -m, --model <value>         Full model file path.
+  -mg,--main-gpu <value>      Set main GPU ID (0 - n) for single GPU mode.
+  -sm,--split-mode <value>    How to split the model across multiple GPUs, one of:
+                              - none: use one GPU only
+                              - layer (default): split layers and KV across GPUs
+                              - row: split rows across GPUs
+  -ngl,--n-gpu-layers <value> Max. number of layers to store in VRAM (default: -1)
+  -lv,--log-verbosity <value> Set the verbosity threshold. Messages with a higher verbosity will be
+                              ignored. Values:
+                              - 0: generic output
+                              - 1: error
+                              - 2: warning
+                              - 3: info
+                              - 4: debug
+
+
+EOF
+}
+
+BIN_FILE=./build/bin/llama-completion
+SEED=0
+GPUS_SETTING=""
+
+INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
+MODEL_FILE=models/llama-2-7b.Q4_0.gguf
+NGL=99
+CONTEXT=4096
+GGML_SYCL_DEVICE=-1
+SPLIT_MODE=layer
+LOG_VERBOSE=3
+while [[ $# -gt 0 ]]; do
+    case "$1" in
+        -c|--context)
+            CONTEXT=$2
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -p|--prompt)
+            INPUT_PROMPT="$2"
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -m|--model)
+            MODEL_FILE="$2"
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -mg|--main-gpu)
+            GGML_SYCL_DEVICE=$2
+            SPLIT_MODE=none
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -sm|--split-mode)
+            SPLIT_MODE=$2
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -ngl|--n-gpu-layers)
+            NGL=$2
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -lv|--log-verbosity)
+            LOG_VERBOSE=$2
+            # Shift twice to consume both the option flag and its value
+            shift
+            shift
+            ;;
+        -h|--help)
+            Help
+            exit 0
+            ;;
+        *)
+            # Handle unknown options or stop processing options
+            echo "Invalid option: $1"
+            # Optional: exit script or shift to treat remaining as positional args
+            exit 1
+            ;;
+    esac
+done
+
+
+
+source /opt/intel/oneapi/setvars.sh
+
+#export GGML_SYCL_DEBUG=1
+
+# ZES_ENABLE_SYSMAN=1 enables querying the free GPU memory via sycl::aspect::ext_intel_free_memory. Recommended when --split-mode = layer.
+
+# allow allocating device memory larger than 4 GB
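+# example invocation (paths are illustrative; all flags are parsed above):
+#   ./examples/sycl/test.sh -m models/llama-2-7b.Q4_0.gguf -mg 0 -c 2048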
+export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
+echo "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=${UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS}"
+
+if [ $GGML_SYCL_DEVICE -ne -1 ]; then
+    echo "Use $GGML_SYCL_DEVICE as main GPU"
+    #use single GPU only
+    GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
+    export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
+    echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
+else
+    echo "Use all Intel GPUs, including iGPU & dGPU"
+fi
+
+echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
+ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap

diff --git a/llama.cpp/examples/sycl/win-build-sycl.bat b/llama.cpp/examples/sycl/win-build-sycl.bat
new file mode 100644
index 0000000..fc8b33b
--- /dev/null
+++ b/llama.cpp/examples/sycl/win-build-sycl.bat
@@ -0,0 +1,31 @@
+
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+
+IF not exist build (mkdir build)
+cd build
+if %errorlevel% neq 0 goto ERROR
+
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+if %errorlevel% neq 0 goto ERROR
+
+:: for FP16
+:: faster for long-prompt inference
+:: cmake -G "MinGW Makefiles" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release -DGGML_SYCL_F16=ON
+
+:: for FP32
+cmake -G "Ninja" .. -DLLAMA_OPENSSL=OFF -DGGML_SYCL=ON -DCMAKE_C_COMPILER=cl -DCMAKE_CXX_COMPILER=icx -DBUILD_SHARED_LIBS=ON -DCMAKE_BUILD_TYPE=Release
+if %errorlevel% neq 0 goto ERROR
+
+:: build all binaries
+cmake --build . -j
+if %errorlevel% neq 0 goto ERROR
+
+cd ..
+exit /B 0
+
+:ERROR
+echo command error: %errorlevel%
+exit /B %errorlevel%
diff --git a/llama.cpp/examples/sycl/win-run-llama2.bat b/llama.cpp/examples/sycl/win-run-llama2.bat
new file mode 100644
index 0000000..1f2dab8
--- /dev/null
+++ b/llama.cpp/examples/sycl/win-run-llama2.bat
@@ -0,0 +1,11 @@
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+
+:: allow allocating device memory larger than 4 GB
+set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
+set LOAD_MODE="--mmap"
+.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%
diff --git a/llama.cpp/examples/sycl/win-test.bat b/llama.cpp/examples/sycl/win-test.bat
new file mode 100644
index 0000000..1f2dab8
--- /dev/null
+++ b/llama.cpp/examples/sycl/win-test.bat
@@ -0,0 +1,11 @@
+:: MIT license
+:: Copyright (C) 2024 Intel Corporation
+:: SPDX-License-Identifier: MIT
+
+set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
+@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
+
+:: allow allocating device memory larger than 4 GB
+set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 +set LOAD_MODE="--mmap" +.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE% diff --git a/llama.cpp/examples/training/CMakeLists.txt b/llama.cpp/examples/training/CMakeLists.txt new file mode 100644 index 0000000..64afe6d --- /dev/null +++ b/llama.cpp/examples/training/CMakeLists.txt @@ -0,0 +1,5 @@ +set(TARGET llama-finetune) +add_executable(${TARGET} finetune.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) diff --git a/llama.cpp/examples/training/README.md b/llama.cpp/examples/training/README.md new file mode 100644 index 0000000..df42527 --- /dev/null +++ b/llama.cpp/examples/training/README.md @@ -0,0 +1,17 @@ +# llama.cpp/examples/training + +This directory contains examples related to language model training using llama.cpp/GGML. +So far finetuning is technically functional (for FP32 models and limited hardware setups) but the code is very much WIP. +Finetuning of Stories 260K and LLaMA 3.2 1b seems to work with 24 GB of memory. +**For CPU training, compile llama.cpp without any additional backends such as CUDA.** +**For CUDA training, use the maximum number of GPU layers.** + +Proof of concept: + +``` sh +export model_name=llama_3.2-1b && export quantization=f32 +./build/bin/llama-finetune --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf -c 512 -b 512 -ub 512 +./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf +``` + +The perplexity value of the finetuned model should be lower after training on the test set for 2 epochs. 
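+
+To see the effect of finetuning, one way (a minimal sketch, reusing only the commands above) is to compare the perplexity of the original and the finetuned weights on the same file:
+
+``` sh
+# baseline: perplexity of the original weights
+./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model models/${model_name}-${quantization}.gguf
+# after training: perplexity of the finetuned weights produced by the example above
+./build/bin/llama-perplexity --file wikitext-2-raw/wiki.test.raw -ngl 999 --model finetuned-model.gguf
+```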
diff --git a/llama.cpp/examples/training/finetune.cpp b/llama.cpp/examples/training/finetune.cpp
new file mode 100644
index 0000000..c82de8d
--- /dev/null
+++ b/llama.cpp/examples/training/finetune.cpp
@@ -0,0 +1,97 @@
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
+#include <vector>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+int main(int argc, char ** argv) {
+    common_params params;
+    params.escape = false;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
+        return 1;
+    }
+
+    if (params.use_mmap) {
+        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n",
+                __func__);
+        params.use_mmap = false;
+    }
+    if (params.cache_type_k != GGML_TYPE_F32) {
+        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_k = GGML_TYPE_F32;
+    }
+    if (params.cache_type_v != GGML_TYPE_F32) {
+        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
+        params.cache_type_v = GGML_TYPE_F32;
+    }
+
+    common_init();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+    // load the model and apply lora adapter, if any
+    auto llama_init = common_init_from_params(params);
+
+    auto * model = llama_init->model();
+    auto * ctx = llama_init->context();
+
+    if (model == NULL) {
+        LOG_ERR("%s: unable to load model\n", __func__);
+        return 1;
+    }
+
+    // print system information
+    {
+        LOG_INF("\n");
+        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    }
+
+    std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx, tokens, llama_n_ctx(ctx) / 2);
+
+    struct lr_opt & lr = params.lr;
+    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
+            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
+            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train     =*/0,
+        /*param_filter    =*/llama_opt_param_filter_all,
+        /*param_filter_ud =*/nullptr,
+        /*get_opt_pars    =*/common_opt_lr_pars,
+        /*get_opt_pars_ud =*/&params.lr,
+        /*optimizer_type  =*/params.optimizer,
+    };
+    llama_opt_init(ctx, model, lopt_params);
+
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
+
+    ggml_opt_result_t result_train = ggml_opt_result_init();
+    ggml_opt_result_t result_eval = ggml_opt_result_init();
+
+    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
+        llama_opt_epoch(ctx, dataset, result_train, result_eval, idata_split,
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+        fprintf(stderr, "\n");
+
+        ggml_opt_result_reset(result_train);
+        ggml_opt_result_reset(result_eval);
+    }
+    ggml_opt_result_free(result_train);
+    ggml_opt_result_free(result_eval);
+
+    llama_model_save_to_file(model, params.out_file.c_str());
+
+    llama_backend_free();
+
+    return 0;
+}
diff --git a/llama.cpp/examples/ts-type-to-grammar.sh b/llama.cpp/examples/ts-type-to-grammar.sh
new file mode 100755
index 0000000..9660504
--- /dev/null
+++ b/llama.cpp/examples/ts-type-to-grammar.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+#
+# ./examples/ts-type-to-grammar.sh
"{a:string,b:string,c?:string}" +# python examples/json_schema_to_grammar.py https://json.schemastore.org/tsconfig.json +# +set -euo pipefail + +readonly type="$1" + +# Create a temporary directory +TMPDIR="" +trap 'rm -fR "$TMPDIR"' EXIT +TMPDIR=$(mktemp -d) + +DTS_FILE="$TMPDIR/type.d.ts" +SCHEMA_FILE="$TMPDIR/schema.json" + +echo "export type MyType = $type" > "$DTS_FILE" + +# This is a fork of typescript-json-schema, actively maintained as of March 2024: +# https://github.com/vega/ts-json-schema-generator +npx ts-json-schema-generator --unstable --no-top-ref --path "$DTS_FILE" --type MyType -e none > "$SCHEMA_FILE" + +# Alternative, not actively maintained as of March 2024: +# https://github.com/YousefED/typescript-json-schema +# npx typescript-json-schema --defaultProps --required "$DTS_FILE" MyType | tee "$SCHEMA_FILE" >&2 + +./examples/json_schema_to_grammar.py "$SCHEMA_FILE" |
