summaryrefslogtreecommitdiff
path: root/llama.cpp/examples/llama.android/lib/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/examples/llama.android/lib/src/main/java')
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt14
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt89
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt61
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt132
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt77
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt324
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt590
7 files changed, 1287 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
new file mode 100644
index 0000000..b72a24e
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
@@ -0,0 +1,14 @@
+package com.arm.aichat
+
+import android.content.Context
+import com.arm.aichat.internal.InferenceEngineImpl
+
+/**
+ * Main entry point for Arm's AI Chat library.
+ */
+object AiChat {
+ /**
+ * Get the inference engine single instance.
+ */
+ fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
+}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
new file mode 100644
index 0000000..26c1668
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
@@ -0,0 +1,89 @@
+package com.arm.aichat
+
+import com.arm.aichat.InferenceEngine.State
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.StateFlow
+
+/**
+ * Interface defining the core LLM inference operations.
+ */
+interface InferenceEngine {
+ /**
+ * Current state of the inference engine
+ */
+ val state: StateFlow<State>
+
+ /**
+ * Load a model from the given path.
+ *
+ * @throws UnsupportedArchitectureException if model architecture not supported
+ */
+ suspend fun loadModel(pathToModel: String)
+
+ /**
+ * Sends a system prompt to the loaded model
+ */
+ suspend fun setSystemPrompt(systemPrompt: String)
+
+ /**
+ * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
+ */
+ fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
+
+ /**
+ * Runs a benchmark with the specified parameters.
+ */
+ suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
+
+ /**
+ * Unloads the currently loaded model.
+ */
+ fun cleanUp()
+
+ /**
+ * Cleans up resources when the engine is no longer needed.
+ */
+ fun destroy()
+
+ /**
+ * States of the inference engine
+ */
+ sealed class State {
+ object Uninitialized : State()
+ object Initializing : State()
+ object Initialized : State()
+
+ object LoadingModel : State()
+ object UnloadingModel : State()
+ object ModelReady : State()
+
+ object Benchmarking : State()
+ object ProcessingSystemPrompt : State()
+ object ProcessingUserPrompt : State()
+
+ object Generating : State()
+
+ data class Error(val exception: Exception) : State()
+ }
+
+ companion object {
+ const val DEFAULT_PREDICT_LENGTH = 1024
+ }
+}
+
+val State.isUninterruptible
+ get() = this is State.Initializing ||
+ this is State.LoadingModel ||
+ this is State.UnloadingModel ||
+ this is State.Benchmarking ||
+ this is State.ProcessingSystemPrompt ||
+ this is State.ProcessingUserPrompt
+
+val State.isModelLoaded: Boolean
+ get() = this is State.ModelReady ||
+ this is State.Benchmarking ||
+ this is State.ProcessingSystemPrompt ||
+ this is State.ProcessingUserPrompt ||
+ this is State.Generating
+
+class UnsupportedArchitectureException : Exception()
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
new file mode 100644
index 0000000..2f15eef
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
@@ -0,0 +1,61 @@
+package com.arm.aichat.gguf
+
+import kotlin.collections.get
+
+
+/**
+ * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
+ * The `label` matches what llama‑cli prints.
+ */
+enum class FileType(val code: Int, val label: String) {
+ ALL_F32(0, "all F32"),
+ MOSTLY_F16(1, "F16"),
+ MOSTLY_Q4_0(2, "Q4_0"),
+ MOSTLY_Q4_1(3, "Q4_1"),
+ // 4 removed
+ MOSTLY_Q8_0(7, "Q8_0"),
+ MOSTLY_Q5_0(8, "Q5_0"),
+ MOSTLY_Q5_1(9, "Q5_1"),
+
+ /* K‑quants ------------------------------------------------------------ */
+ MOSTLY_Q2_K (10, "Q2_K - Medium"),
+ MOSTLY_Q3_K_S (11, "Q3_K - Small"),
+ MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
+ MOSTLY_Q3_K_L (13, "Q3_K - Large"),
+ MOSTLY_Q4_K_S (14, "Q4_K - Small"),
+ MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
+ MOSTLY_Q5_K_S (16, "Q5_K - Small"),
+ MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
+ MOSTLY_Q6_K (18, "Q6_K"),
+
+ /* IQ quants ----------------------------------------------------------- */
+ MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
+ MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
+ MOSTLY_Q2_K_S (21, "Q2_K - Small"),
+ MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
+ MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
+ MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
+ MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
+ MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
+ MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
+ MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
+ MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
+ MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
+ MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),
+
+ /* BF16 & Ternary ------------------------------------------------------ */
+ MOSTLY_BF16 (32, "BF16"),
+ MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
+ MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),
+
+ /* Special flag -------------------------------------------------------- */
+ GUESSED(1024, "(guessed)"),
+
+ UNKNOWN(-1, "unknown");
+
+ companion object {
+ private val map = entries.associateBy(FileType::code)
+
+ fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
+ }
+}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
new file mode 100644
index 0000000..5e1971a
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
@@ -0,0 +1,132 @@
+package com.arm.aichat.gguf
+
+import java.io.IOException
+
+
+/**
+ * Structured metadata of GGUF
+ */
+data class GgufMetadata(
+ // Basic file info
+ val version: GgufVersion,
+ val tensorCount: Long,
+ val kvCount: Long,
+
+ // General info
+ val basic: BasicInfo,
+ val author: AuthorInfo? = null,
+ val additional: AdditionalInfo? = null,
+ val architecture: ArchitectureInfo? = null,
+ val baseModels: List<BaseModelInfo>? = null,
+ val tokenizer: TokenizerInfo? = null,
+
+ // Derivative info
+ val dimensions: DimensionsInfo? = null,
+ val attention: AttentionInfo? = null,
+ val rope: RopeInfo? = null,
+ val experts: ExpertsInfo? = null
+) {
+ enum class GgufVersion(val code: Int, val label: String) {
+ /** First public draft; little‑endian only, no alignment key. */
+ LEGACY_V1(1, "Legacy v1"),
+
+ /** Added split‑file support and some extra metadata keys. */
+ EXTENDED_V2(2, "Extended v2"),
+
+ /** Current spec: endian‑aware, mandatory alignment, fully validated. */
+ VALIDATED_V3(3, "Validated v3");
+
+ companion object {
+ fun fromCode(code: Int): GgufVersion =
+ entries.firstOrNull { it.code == code }
+ ?: throw IOException("Unknown GGUF version code $code")
+ }
+
+ override fun toString(): String = "$label (code=$code)"
+ }
+
+ data class BasicInfo(
+ val uuid: String? = null,
+ val name: String? = null,
+ val nameLabel: String? = null,
+ val sizeLabel: String? = null, // Size label like "7B"
+ )
+
+ data class AuthorInfo(
+ val organization: String? = null,
+ val author: String? = null,
+ val doi: String? = null,
+ val url: String? = null,
+ val repoUrl: String? = null,
+ val license: String? = null,
+ val licenseLink: String? = null,
+ )
+
+ data class AdditionalInfo(
+ val type: String? = null,
+ val description: String? = null,
+ val tags: List<String>? = null,
+ val languages: List<String>? = null,
+ )
+
+ data class ArchitectureInfo(
+ val architecture: String? = null,
+ val fileType: Int? = null,
+ val vocabSize: Int? = null,
+ val finetune: String? = null,
+ val quantizationVersion: Int? = null,
+ )
+
+ data class BaseModelInfo(
+ val name: String? = null,
+ val author: String? = null,
+ val version: String? = null,
+ val organization: String? = null,
+ val url: String? = null,
+ val doi: String? = null,
+ val uuid: String? = null,
+ val repoUrl: String? = null,
+ )
+
+ data class TokenizerInfo(
+ val model: String? = null,
+ val bosTokenId: Int? = null,
+ val eosTokenId: Int? = null,
+ val unknownTokenId: Int? = null,
+ val paddingTokenId: Int? = null,
+ val addBosToken: Boolean? = null,
+ val addEosToken: Boolean? = null,
+ val chatTemplate: String? = null,
+ )
+
+ data class DimensionsInfo(
+ val contextLength: Int? = null,
+ val embeddingSize: Int? = null,
+ val blockCount: Int? = null,
+ val feedForwardSize: Int? = null,
+ )
+
+ data class AttentionInfo(
+ val headCount: Int? = null,
+ val headCountKv: Int? = null,
+ val keyLength: Int? = null,
+ val valueLength: Int? = null,
+ val layerNormEpsilon: Float? = null,
+ val layerNormRmsEpsilon: Float? = null,
+ )
+
+ data class RopeInfo(
+ val frequencyBase: Float? = null,
+ val dimensionCount: Int? = null,
+ val scalingType: String? = null,
+ val scalingFactor: Float? = null,
+ val attnFactor: Float? = null,
+ val originalContextLength: Int? = null,
+ val finetuned: Boolean? = null,
+ )
+
+ data class ExpertsInfo(
+ val count: Int? = null,
+ val usedCount: Int? = null,
+ )
+}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
new file mode 100644
index 0000000..264a6c0
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
@@ -0,0 +1,77 @@
+package com.arm.aichat.gguf
+
+import android.content.Context
+import android.net.Uri
+import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
+import java.io.File
+import java.io.IOException
+import java.io.InputStream
+
+/**
+ * Interface for reading GGUF metadata from model files.
+ * Use `GgufMetadataReader.create()` to get an instance.
+ */
+interface GgufMetadataReader {
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param file Java File to the GGUF file with absolute path
+ * @return true if file is valid GGUF, otherwise false
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun ensureSourceFileFormat(file: File): Boolean
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param context Context for obtaining [android.content.ContentProvider]
+ * @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
+ * @return true if file is valid GGUF, otherwise false
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean
+
+ /**
+ * Reads and parses GGUF metadata from the specified file path.
+ *
+ * @param input the [InputStream] obtained from a readable file or content
+ * @return Structured metadata extracted from the file
+ * @throws IOException if file is damaged or cannot be read
+ * @throws InvalidFileFormatException if file format is invalid
+ */
+ suspend fun readStructuredMetadata(input: InputStream): GgufMetadata
+
+ companion object {
+ private val DEFAULT_SKIP_KEYS = setOf(
+ "tokenizer.chat_template",
+ "tokenizer.ggml.scores",
+ "tokenizer.ggml.tokens",
+ "tokenizer.ggml.token_type"
+ )
+
+ /**
+ * Creates a default GgufMetadataReader instance
+ */
+ fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
+ skipKeys = DEFAULT_SKIP_KEYS,
+ arraySummariseThreshold = 1_000
+ )
+
+ /**
+ * Creates a GgufMetadataReader with custom configuration
+ *
+ * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
+ * @param arraySummariseThreshold If ≥0, arrays longer get summarised, not materialised;
+ * If -1, never summarise.
+ */
+ fun create(
+ skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
+ arraySummariseThreshold: Int = 1_000
+ ): GgufMetadataReader = GgufMetadataReaderImpl(
+ skipKeys = skipKeys,
+ arraySummariseThreshold = arraySummariseThreshold
+ )
+ }
+}
+
+class InvalidFileFormatException : IOException()
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
new file mode 100644
index 0000000..a293796
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
@@ -0,0 +1,324 @@
+package com.arm.aichat.internal
+
+import android.content.Context
+import android.util.Log
+import com.arm.aichat.InferenceEngine
+import com.arm.aichat.UnsupportedArchitectureException
+import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
+import dalvik.annotation.optimization.FastNative
+import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.ExperimentalCoroutinesApi
+import kotlinx.coroutines.SupervisorJob
+import kotlinx.coroutines.cancel
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.flow.MutableStateFlow
+import kotlinx.coroutines.flow.StateFlow
+import kotlinx.coroutines.flow.asStateFlow
+import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withContext
+import java.io.File
+import java.io.IOException
+
+/**
+ * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
+ *
+ * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
+ * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
+ * with the underlying C++ native code.
+ *
+ * The typical usage flow is:
+ * 1. Get instance via [getInstance]
+ * 2. Load a model with [loadModel]
+ * 3. Send prompts with [sendUserPrompt]
+ * 4. Generate responses as token streams
+ * 5. Perform [cleanUp] when done with a model
+ * 6. Properly [destroy] when completely done
+ *
+ * State transitions are managed automatically and validated at each operation.
+ *
+ * @see ai_chat.cpp for the native implementation details
+ */
+internal class InferenceEngineImpl private constructor(
+ private val nativeLibDir: String
+) : InferenceEngine {
+
+ companion object {
+ private val TAG = InferenceEngineImpl::class.java.simpleName
+
+ @Volatile
+ private var instance: InferenceEngine? = null
+
+ /**
+ * Create or obtain [InferenceEngineImpl]'s single instance.
+ *
+ * @param Context for obtaining native library directory
+ * @throws IllegalArgumentException if native library path is invalid
+ * @throws UnsatisfiedLinkError if library failed to load
+ */
+ internal fun getInstance(context: Context) =
+ instance ?: synchronized(this) {
+ val nativeLibDir = context.applicationInfo.nativeLibraryDir
+ require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }
+
+ try {
+ Log.i(TAG, "Instantiating InferenceEngineImpl,,,")
+ InferenceEngineImpl(nativeLibDir).also { instance = it }
+ } catch (e: UnsatisfiedLinkError) {
+ Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * JNI methods
+ * @see ai_chat.cpp
+ */
+ @FastNative
+ private external fun init(nativeLibDir: String)
+
+ @FastNative
+ private external fun load(modelPath: String): Int
+
+ @FastNative
+ private external fun prepare(): Int
+
+ @FastNative
+ private external fun systemInfo(): String
+
+ @FastNative
+ private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String
+
+ @FastNative
+ private external fun processSystemPrompt(systemPrompt: String): Int
+
+ @FastNative
+ private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int
+
+ @FastNative
+ private external fun generateNextToken(): String?
+
+ @FastNative
+ private external fun unload()
+
+ @FastNative
+ private external fun shutdown()
+
+ private val _state =
+ MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
+ override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()
+
+ private var _readyForSystemPrompt = false
+ @Volatile
+ private var _cancelGeneration = false
+
+ /**
+ * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
+ */
+ @OptIn(ExperimentalCoroutinesApi::class)
+ private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
+ private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())
+
+ init {
+ llamaScope.launch {
+ try {
+ check(_state.value is InferenceEngine.State.Uninitialized) {
+ "Cannot load native library in ${_state.value.javaClass.simpleName}!"
+ }
+ _state.value = InferenceEngine.State.Initializing
+ Log.i(TAG, "Loading native library...")
+ System.loadLibrary("ai-chat")
+ init(nativeLibDir)
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")
+
+ } catch (e: Exception) {
+ Log.e(TAG, "Failed to load native library", e)
+ throw e
+ }
+ }
+ }
+
+ /**
+ * Load the LLM
+ */
+ override suspend fun loadModel(pathToModel: String) =
+ withContext(llamaDispatcher) {
+ check(_state.value is InferenceEngine.State.Initialized) {
+ "Cannot load model in ${_state.value.javaClass.simpleName}!"
+ }
+
+ try {
+ Log.i(TAG, "Checking access to model file... \n$pathToModel")
+ File(pathToModel).let {
+ require(it.exists()) { "File not found" }
+ require(it.isFile) { "Not a valid file" }
+ require(it.canRead()) { "Cannot read file" }
+ }
+
+ Log.i(TAG, "Loading model... \n$pathToModel")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.LoadingModel
+ load(pathToModel).let {
+ // TODO-han.yin: find a better way to pass other error codes
+ if (it != 0) throw UnsupportedArchitectureException()
+ }
+ prepare().let {
+ if (it != 0) throw IOException("Failed to prepare resources")
+ }
+ Log.i(TAG, "Model loaded!")
+ _readyForSystemPrompt = true
+
+ _cancelGeneration = false
+ _state.value = InferenceEngine.State.ModelReady
+ } catch (e: Exception) {
+ Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
+ _state.value = InferenceEngine.State.Error(e)
+ throw e
+ }
+ }
+
+ /**
+ * Process the plain text system prompt
+ *
+ * TODO-han.yin: return error code if system prompt not correct processed?
+ */
+ override suspend fun setSystemPrompt(prompt: String) =
+ withContext(llamaDispatcher) {
+ require(prompt.isNotBlank()) { "Cannot process empty system prompt!" }
+ check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
+ }
+
+ Log.i(TAG, "Sending system prompt...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.ProcessingSystemPrompt
+ processSystemPrompt(prompt).let { result ->
+ if (result != 0) {
+ RuntimeException("Failed to process system prompt: $result").also {
+ _state.value = InferenceEngine.State.Error(it)
+ throw it
+ }
+ }
+ }
+ Log.i(TAG, "System prompt processed! Awaiting user prompt...")
+ _state.value = InferenceEngine.State.ModelReady
+ }
+
+ /**
+ * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
+ */
+ override fun sendUserPrompt(
+ message: String,
+ predictLength: Int,
+ ): Flow<String> = flow {
+ require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
+ }
+
+ try {
+ Log.i(TAG, "Sending user prompt...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.ProcessingUserPrompt
+
+ processUserPrompt(message, predictLength).let { result ->
+ if (result != 0) {
+ Log.e(TAG, "Failed to process user prompt: $result")
+ return@flow
+ }
+ }
+
+ Log.i(TAG, "User prompt processed. Generating assistant prompt...")
+ _state.value = InferenceEngine.State.Generating
+ while (!_cancelGeneration) {
+ generateNextToken()?.let { utf8token ->
+ if (utf8token.isNotEmpty()) emit(utf8token)
+ } ?: break
+ }
+ if (_cancelGeneration) {
+ Log.i(TAG, "Assistant generation aborted per requested.")
+ } else {
+ Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
+ }
+ _state.value = InferenceEngine.State.ModelReady
+ } catch (e: CancellationException) {
+ Log.i(TAG, "Assistant generation's flow collection cancelled.")
+ _state.value = InferenceEngine.State.ModelReady
+ throw e
+ } catch (e: Exception) {
+ Log.e(TAG, "Error during generation!", e)
+ _state.value = InferenceEngine.State.Error(e)
+ throw e
+ }
+ }.flowOn(llamaDispatcher)
+
+ /**
+ * Benchmark the model
+ */
+ override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
+ withContext(llamaDispatcher) {
+ check(_state.value is InferenceEngine.State.ModelReady) {
+ "Benchmark request discarded due to: $state"
+ }
+ Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
+ _readyForSystemPrompt = false // Just to be safe
+ _state.value = InferenceEngine.State.Benchmarking
+ benchModel(pp, tg, pl, nr).also {
+ _state.value = InferenceEngine.State.ModelReady
+ }
+ }
+
+ /**
+ * Unloads the model and frees resources, or reset error states
+ */
+ override fun cleanUp() {
+ _cancelGeneration = true
+ runBlocking(llamaDispatcher) {
+ when (val state = _state.value) {
+ is InferenceEngine.State.ModelReady -> {
+ Log.i(TAG, "Unloading model and free resources...")
+ _readyForSystemPrompt = false
+ _state.value = InferenceEngine.State.UnloadingModel
+
+ unload()
+
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "Model unloaded!")
+ Unit
+ }
+
+ is InferenceEngine.State.Error -> {
+ Log.i(TAG, "Resetting error states...")
+ _state.value = InferenceEngine.State.Initialized
+ Log.i(TAG, "States reset!")
+ Unit
+ }
+
+ else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
+ }
+ }
+ }
+
+ /**
+ * Cancel all ongoing coroutines and free GGML backends
+ */
+ override fun destroy() {
+ _cancelGeneration = true
+ runBlocking(llamaDispatcher) {
+ _readyForSystemPrompt = false
+ when(_state.value) {
+ is InferenceEngine.State.Uninitialized -> {}
+ is InferenceEngine.State.Initialized -> shutdown()
+ else -> { unload(); shutdown() }
+ }
+ }
+ llamaScope.cancel()
+ }
+}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
new file mode 100644
index 0000000..bf250ac
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
@@ -0,0 +1,590 @@
+package com.arm.aichat.internal.gguf
+
+import android.content.Context
+import android.net.Uri
+import com.arm.aichat.gguf.GgufMetadata
+import com.arm.aichat.gguf.GgufMetadataReader
+import com.arm.aichat.gguf.InvalidFileFormatException
+import java.io.File
+import java.io.IOException
+import java.io.InputStream
+
+
+/**
+ * Utility class to read GGUF model files and extract metadata key-value pairs.
+ * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
+ */
+internal class GgufMetadataReaderImpl(
+ private val skipKeys: Set<String>,
+ private val arraySummariseThreshold: Int,
+) : GgufMetadataReader {
+ companion object {
+ private const val ARCH_LLAMA = "llama"
+ }
+
+ /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
+ enum class MetadataType(val code: Int) {
+ UINT8(0), INT8(1), UINT16(2), INT16(3),
+ UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
+ STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);
+ companion object {
+ private val codeMap = entries.associateBy(MetadataType::code)
+ fun fromCode(code: Int): MetadataType = codeMap[code]
+ ?: throw IOException("Unknown metadata value type code: $code")
+ }
+ }
+
+ /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
+ sealed class MetadataValue {
+ data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int
+ data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int
+ data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian)
+ data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian)
+ data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian)
+ data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian)
+ data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float
+ data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true)
+ data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
+ data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
+ data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian)
+ data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian)
+ data class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double
+ }
+
+ /* Convert MetadataValue to plain Kotlin primitives for allMetadata map */
+ private fun MetadataValue.toPrimitive(): Any = when (this) {
+ is MetadataValue.UInt8 -> value
+ is MetadataValue.Int8 -> value
+ is MetadataValue.UInt16 -> value
+ is MetadataValue.Int16 -> value
+ is MetadataValue.UInt32 -> value
+ is MetadataValue.Int32 -> value
+ is MetadataValue.Float32 -> value
+ is MetadataValue.Bool -> value
+ is MetadataValue.StringVal -> value
+ is MetadataValue.UInt64 -> value
+ is MetadataValue.Int64 -> value
+ is MetadataValue.Float64 -> value
+ is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
+ }
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param context Context for obtaining ContentResolver
+ * @param uri Uri to the GGUF file provided by ContentProvider
+ * @return true if file is valid GGUF, otherwise false
+ */
+ override suspend fun ensureSourceFileFormat(file: File): Boolean =
+ file.inputStream().buffered().use { ensureMagic(it) }
+
+ /**
+ * Reads the magic number from the specified file path.
+ *
+ * @param context Context for obtaining ContentResolver
+ * @param uri Uri to the GGUF file provided by ContentProvider
+ * @return true if file is valid GGUF, otherwise false
+ */
+ override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
+ context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true
+
+ /** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */
+ private fun ensureMagic(input: InputStream): Boolean =
+ ByteArray(4).let {
+ if (input.read(it) != 4) throw IOException("Not a valid file!")
+ it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
+ }
+
+ /**
+ * High‑level entry point: parses a `.gguf` file on disk and returns the fully
+ * populated [GgufMetadata] tree.
+ *
+ * Steps performed internally:
+ * 1. Reads and validates the 8‑byte header (`"GGUF"` magic + version).
+ * 2. Streams through the key‑value section, skipping large blobs if the key
+ * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold].
+ * 3. Converts the resulting raw map into strongly‑typed sub‑structures
+ * (basic info, tokenizer, rope, etc.).
+ *
+ * The method is STREAMING‑ONLY: tensors are never mapped or loaded into
+ * memory, so even multi‑GB model files can be processed in < 50 ms.
+ *
+ * @param path Absolute or relative filesystem path to a `.gguf` file.
+ * @return A [GgufMetadata] instance containing all recognised metadata plus
+ * an `allMetadata` map with any keys that were not given a dedicated
+ * field.
+ * @throws IOException if the file is not GGUF, the version is unsupported,
+ * or the metadata block is truncated / corrupt.
+ */
+ override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
+ // ── 1. header ──────────────────────────────────────────────────────────
+ // throws on mismatch
+ val version = ensureMagicAndVersion(input)
+ val tensorCount = readLittleLong(input)
+ val kvCount = readLittleLong(input)
+
+ // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ──
+ val meta = readMetaMap(input, kvCount) // <String, MetadataValue>
+
+ // ── 3. build structured object ────────────────────────────────────────
+ return buildStructured(meta, version, tensorCount, kvCount)
+ }
+
+ /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */
+ private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
+ if (!ensureMagic(input)) throw InvalidFileFormatException()
+ return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
+ }
+
+ /**
+ * Read an unsigned 32‑bit little‑endian integer.
+ *
+ * @throws IOException if fewer than four bytes are available.
+ */
+ private fun readLEUInt32(input: InputStream): Int {
+ val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read()
+ if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32")
+ return (b3 and 0xFF shl 24) or
+ (b2 and 0xFF shl 16) or
+ (b1 and 0xFF shl 8) or
+ (b0 and 0xFF)
+ }
+
+ /**
+ * Low‑level helper that reads the entire “key-value” section from the current
+ * stream position.
+ *
+ * @param input Open stream positioned JUST AFTER the header.
+ * @param kvCnt Number of key‑value pairs (taken from the header).
+ * @return Mutable map with one [MetadataValue] for every key that is NOT skipped.
+ *
+ * The function honours [skipKeys] and [arraySummariseThreshold] by invoking
+ * [skipValue] or [parseValue] accordingly.
+ */
+ private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
+ mutableMapOf<String, MetadataValue>().apply {
+ repeat(kvCnt.toInt()) {
+ val key = readString(input)
+ val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ if (key in skipKeys) {
+ skipValue(input, valueT)
+ } else {
+ this[key] = parseValue(input, valueT)
+ }
+ }
+ }
+
+ /**
+ * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed
+ * [GgufMetadata] tree used by the rest of the app.
+ *
+ * Only the keys listed in the spec are copied into dedicated data classes;
+ * everything else is preserved in `GgufMetadata.allMetadata`.
+ *
+ * @param m Raw key/value map.
+ * @param version GGUF file‑format version (enum).
+ * @param tensorCnt Number of tensors (from the header).
+ * @param kvCnt Total metadata pair count (from the header).
+ */
+ private fun buildStructured(
+ m: Map<String, MetadataValue>,
+ version: GgufMetadata.GgufVersion,
+ tensorCnt: Long,
+ kvCnt: Long
+ ): GgufMetadata {
+ // ---------- helpers ----------
+ fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
+ fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
+ fun String.i32() = (m[this] as? MetadataValue.Int32)?.value
+ fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
+ fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
+ fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat()
+ fun String.strList(): List<String>? =
+ (m[this] as? MetadataValue.ArrayVal)
+ ?.elements
+ ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }
+
+ val arch = "general.architecture".str() ?: ARCH_LLAMA
+
+ // -------------- populate sections ----------------
+ val basic = GgufMetadata.BasicInfo(
+ uuid = "general.uuid".str(),
+ name = "general.basename".str(),
+ nameLabel = "general.name".str(),
+ sizeLabel = "general.size_label".str()
+ )
+
+ val author = GgufMetadata.AuthorInfo(
+ organization = "general.organization".str(),
+ author = "general.author".str(),
+ doi = "general.doi".str(),
+ url = "general.url".str(),
+ repoUrl = "general.repo_url".str(),
+ license = "general.license".str(),
+ licenseLink = "general.license.link".str()
+ ).takeUnless {
+ organization == null && author == null && doi == null &&
+ url == null && repoUrl == null && license == null && licenseLink == null
+ }
+
+ val additional = GgufMetadata.AdditionalInfo(
+ type = "general.type".str(),
+ description = "general.description".str(),
+ tags = "general.tags".strList(),
+ languages = "general.languages".strList()
+ ).takeUnless {
+ type == null && description == null && tags == null && languages == null
+ }
+
+ val architectureInfo = GgufMetadata.ArchitectureInfo(
+ architecture = arch,
+ fileType = "general.file_type".u32(),
+ vocabSize = "$arch.vocab_size".u32(),
+ finetune = "general.finetune".str(),
+ quantizationVersion = "general.quantization_version".u32()
+ ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }
+
+ val baseModels = buildList {
+ val n = "general.base_model.count".u32() ?: 0
+ for (i in 0 until n) {
+ fun k(s: String) = "general.base_model.$i.$s"
+ add(
+ GgufMetadata.BaseModelInfo(
+ name = k("name").str(),
+ author = k("author").str(),
+ version = k("version").str(),
+ organization = k("organization").str(),
+ url = k("url").str(),
+ doi = k("doi").str(),
+ uuid = k("uuid").str(),
+ repoUrl = k("repo_url").str(),
+ )
+ )
+ }
+ }.takeIf { it.isNotEmpty() }
+
+ val tokenizer = GgufMetadata.TokenizerInfo(
+ model = "tokenizer.ggml.model".str(),
+ bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
+ eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
+ unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
+ paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
+ addBosToken = "tokenizer.ggml.add_bos_token".bool(),
+ addEosToken = "tokenizer.ggml.add_eos_token".bool(),
+ chatTemplate = "tokenizer.chat_template".str()
+ ).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
+ unknownTokenId == null && paddingTokenId == null &&
+ addBosToken == null && addEosToken == null && chatTemplate == null
+ }
+
+ val dimensions = GgufMetadata.DimensionsInfo(
+ contextLength = "$arch.context_length".u32(),
+ embeddingSize = "$arch.embedding_length".u32(),
+ blockCount = "$arch.block_count".u32(),
+ feedForwardSize = "$arch.feed_forward_length".u32()
+ ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }
+
+ val attention = GgufMetadata.AttentionInfo(
+ headCount = "$arch.attention.head_count".u32(),
+ headCountKv = "$arch.attention.head_count_kv".u32(),
+ keyLength = "$arch.attention.key_length".u32(),
+ valueLength = "$arch.attention.value_length".u32(),
+ layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
+ layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
+ ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
+ layerNormEpsilon == null && layerNormRmsEpsilon == null
+ }
+
+ val rope = GgufMetadata.RopeInfo(
+ frequencyBase = "$arch.rope.freq_base".f32(),
+ dimensionCount = "$arch.rope.dimension_count".u32(),
+ scalingType = "$arch.rope.scaling.type".str(),
+ scalingFactor = "$arch.rope.scaling.factor".f32(),
+ attnFactor = "$arch.rope.scaling.attn_factor".f32(),
+ originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
+ finetuned = "$arch.rope.scaling.finetuned".bool()
+ ).takeUnless { frequencyBase == null && dimensionCount == null &&
+ scalingType == null && scalingFactor == null && attnFactor == null &&
+ originalContextLength == null && finetuned == null
+ }
+
+ val experts = GgufMetadata.ExpertsInfo(
+ count = "$arch.expert_count".u32(),
+ usedCount = "$arch.expert_used_count".u32()
+ ).takeUnless { count == null && usedCount == null }
+
+ return GgufMetadata(
+ version = version,
+ tensorCount = tensorCnt,
+ kvCount = kvCnt,
+ basic = basic,
+ author = author,
+ additional = additional,
+ architecture = architectureInfo,
+ baseModels = baseModels,
+ tokenizer = tokenizer,
+ dimensions = dimensions,
+ attention = attention,
+ rope = rope,
+ experts = experts
+ )
+ }
+
+ /**
+ * Recursively parses a metadata value of the given type from the input stream.
+ * @param input The input stream positioned at the start of the value.
+ * @param type The metadata value type to parse.
+ */
+ private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
+ MetadataType.UINT8 -> {
+ // 1-byte unsigned integer
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.")
+ MetadataValue.UInt8(byteVal.toUByte())
+ }
+ MetadataType.INT8 -> {
+ // 1-byte signed integer
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.")
+ MetadataValue.Int8(byteVal.toByte())
+ }
+ MetadataType.UINT16 -> {
+ // 2-byte unsigned integer (little-endian)
+ val bytes = ByteArray(2)
+ if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.")
+ // Combine two bytes (little-endian) into an unsigned 16-bit value
+ val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
+ MetadataValue.UInt16(u16.toUShort())
+ }
+ MetadataType.INT16 -> {
+ // 2-byte signed integer (little-endian)
+ val bytes = ByteArray(2)
+ if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.")
+ // Combine to 16-bit and interpret as signed
+ val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF)
+ MetadataValue.Int16(i16.toShort())
+ }
+ MetadataType.UINT32 -> {
+ // 4-byte unsigned integer (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.")
+ // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt
+ val u32 = (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ MetadataValue.UInt32(u32.toUInt())
+ }
+ MetadataType.INT32 -> {
+ // 4-byte signed integer (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.")
+ // Combine four bytes into a 32-bit signed int
+ val i32 = (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+ MetadataValue.Int32(i32)
+ }
+ MetadataType.FLOAT32 -> {
+ // 4-byte IEEE 754 float (little-endian)
+ val bytes = ByteArray(4)
+ if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.")
+ // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float
+ val bits = (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+ val floatVal = Float.fromBits(bits)
+ MetadataValue.Float32(floatVal)
+ }
+ MetadataType.BOOL -> {
+ // 1-byte boolean (0 = false, 1 = true)
+ val byteVal = input.read()
+ if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.")
+ if (byteVal != 0 && byteVal != 1) {
+ throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
+ }
+ MetadataValue.Bool(byteVal != 0)
+ }
+ MetadataType.STRING -> {
+ // UTF-8 string (length-prefixed with 8-byte length)
+ val str = readString(input)
+ MetadataValue.StringVal(str)
+ }
+ MetadataType.ARRAY -> {
+ val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ val len = readLittleLong(input)
+ val count = len.toInt()
+
+ if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
+ // fast‑forward without allocation
+ repeat(count) { skipValue(input, elemType) }
+ MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
+ } else {
+ val list = ArrayList<MetadataValue>(count)
+ repeat(count) { list += parseValue(input, elemType) }
+ MetadataValue.ArrayVal(elemType, list)
+ }
+ }
+ MetadataType.UINT64 -> {
+ // 8-byte unsigned integer (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.")
+ // Combine 8 bytes into an unsigned 64-bit (ULong). Use ULong for full 0 to 2^64-1 range.
+ val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or
+ (bytes[6].toULong() and 0xFFuL shl 48) or
+ (bytes[5].toULong() and 0xFFuL shl 40) or
+ (bytes[4].toULong() and 0xFFuL shl 32) or
+ (bytes[3].toULong() and 0xFFuL shl 24) or
+ (bytes[2].toULong() and 0xFFuL shl 16) or
+ (bytes[1].toULong() and 0xFFuL shl 8) or
+ (bytes[0].toULong() and 0xFFuL)
+ MetadataValue.UInt64(u64)
+ }
+ MetadataType.INT64 -> {
+ // 8-byte signed integer (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.")
+ // Combine 8 bytes into a signed 64-bit value (Long)
+ val i64 = (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ MetadataValue.Int64(i64)
+ }
+ MetadataType.FLOAT64 -> {
+ // 8-byte IEEE 754 double (little-endian)
+ val bytes = ByteArray(8)
+ if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.")
+ // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double
+ val bits = (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ val doubleVal = Double.fromBits(bits)
+ MetadataValue.Float64(doubleVal)
+ }
+ }
+
+
+ private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
+ this?.takeIf { !it.check() }
+
+ /** Helper: Skip a value in the stream without storing it (still maintains pointer). */
+ private fun skipValue(input: InputStream, type: MetadataType) {
+ when (type) {
+ MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
+ MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
+ MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
+ MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
+ MetadataType.STRING -> {
+ val len = readLittleLong(input); input.skipFully(len)
+ }
+ MetadataType.ARRAY -> {
+ val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4)))
+ val len = readLittleLong(input)
+ repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
+ }
+ }
+ }
+
+ /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
+ private fun readLittleLong(input: InputStream): Long {
+ val bytes = ByteArray(8)
+ input.readFully(bytes)
+
+ // Combine 8 bytes into a 64-bit value (Little Endian).
+ // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
+ // In our context (lengths/counts), such extremely large values are not expected.
+ return (bytes[7].toLong() and 0xFFL shl 56) or
+ (bytes[6].toLong() and 0xFFL shl 48) or
+ (bytes[5].toLong() and 0xFFL shl 40) or
+ (bytes[4].toLong() and 0xFFL shl 32) or
+ (bytes[3].toLong() and 0xFFL shl 24) or
+ (bytes[2].toLong() and 0xFFL shl 16) or
+ (bytes[1].toLong() and 0xFFL shl 8) or
+ (bytes[0].toLong() and 0xFFL)
+ }
+
+ /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
+ private fun readString(input: InputStream): String =
+ // Read 8-byte little-endian length (number of bytes in the string).
+ readLittleLong(input).let { len ->
+ if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
+
+ // Read the UTF-8 bytes of the given length.
+ ByteArray(len.toInt()).let {
+ if (it.isNotEmpty()) input.readFully(it)
+ String(it, Charsets.UTF_8)
+ }
+ }
+
+ /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
+ private fun littleEndianBytesToInt(bytes: ByteArray): Int =
+ // Note: assumes bytes length is 4.
+ (bytes[3].toInt() and 0xFF shl 24) or
+ (bytes[2].toInt() and 0xFF shl 16) or
+ (bytes[1].toInt() and 0xFF shl 8) or
+ (bytes[0].toInt() and 0xFF)
+
+ /**
+ * Robust skip that works the same on JDK 11 and Android’s desugared runtime.
+ *
+ * @param n Number of bytes to advance in the stream.
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.skipFully(n: Long) {
+ var remaining = n
+ val scratch = ByteArray(8192) // read‑and‑toss buffer
+ while (remaining > 0) {
+ val skipped = skip(remaining)
+ when {
+ skipped > 0 -> remaining -= skipped // normal fast path
+ skipped == 0L -> {
+ // fallback: read and discard
+ val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
+ if (read == -1) throw IOException("EOF while skipping $n bytes")
+ remaining -= read
+ }
+ else -> throw IOException("Skip returned negative value")
+ }
+ }
+ }
+
+ /**
+ * Extension that keeps reading until the requested number of bytes are filled.
+ * Falls back to `read()` when `skip()` returns 0, which happens on some Android
+ * streams.
+ *
+ * @param buf Destination buffer.
+ * @param len Number of bytes to fill (defaults to `buf.size`).
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
+ var off = 0
+ while (off < len) {
+ val n = read(buf, off, len - off)
+ if (n == -1) throw IOException("EOF after $off of $len bytes")
+ off += n
+ }
+ }
+
+ /**
+ * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array.
+ * This is used for small fixed‑length reads (e.g. 4‑byte type codes).
+ *
+ * @throws IOException on premature EOF.
+ */
+ private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also {
+ if (read(it) != n) throw IOException("Unexpected EOF")
+ }
+}