aboutsummaryrefslogtreecommitdiff
path: root/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat
diff options
context:
space:
mode:
authorMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
committerMitja Felicijan <mitja.felicijan@gmail.com>2026-02-12 20:57:17 +0100
commitb333b06772c89d96aacb5490d6a219fba7c09cc6 (patch)
tree211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat
downloadllmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz
Engage!
Diffstat (limited to 'llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat')
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt14
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt89
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt61
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt132
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt77
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt324
-rw-r--r--llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt590
7 files changed, 1287 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
new file mode 100644
index 0000000..b72a24e
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt
@@ -0,0 +1,14 @@
1package com.arm.aichat
2
3import android.content.Context
4import com.arm.aichat.internal.InferenceEngineImpl
5
/**
 * Main entry point for Arm's AI Chat library.
 */
object AiChat {
    /**
     * Returns the process-wide [InferenceEngine] singleton, creating it on first use.
     *
     * @param context used to resolve the native library directory
     */
    fun getInferenceEngine(context: Context): InferenceEngine =
        InferenceEngineImpl.getInstance(context)
}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
new file mode 100644
index 0000000..26c1668
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt
@@ -0,0 +1,89 @@
1package com.arm.aichat
2
3import com.arm.aichat.InferenceEngine.State
4import kotlinx.coroutines.flow.Flow
5import kotlinx.coroutines.flow.StateFlow
6
/**
 * Contract for the on-device LLM inference engine.
 *
 * Implementations expose a [state] machine; observe it to know which
 * operations are currently legal.
 */
interface InferenceEngine {
    /**
     * Observable lifecycle state of the engine.
     */
    val state: StateFlow<State>

    /**
     * Loads the model file at [pathToModel] and prepares it for inference.
     *
     * @throws UnsupportedArchitectureException if model architecture not supported
     */
    suspend fun loadModel(pathToModel: String)

    /**
     * Feeds [systemPrompt] to the loaded model. Per the implementation's contract,
     * this must happen right after a successful [loadModel].
     */
    suspend fun setSystemPrompt(systemPrompt: String)

    /**
     * Submits [message] to the loaded model and streams the generated tokens.
     *
     * @param predictLength upper bound on the number of tokens to generate
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>

    /**
     * Runs a benchmark and returns a human-readable report.
     *
     * @param pp prompt-processing token count
     * @param tg text-generation token count
     * @param pl parallel sequence count
     * @param nr number of repetitions
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String

    /**
     * Unloads the currently loaded model.
     */
    fun cleanUp()

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy()

    /**
     * States of the inference engine.
     */
    sealed class State {
        object Uninitialized : State()
        object Initializing : State()
        object Initialized : State()

        object LoadingModel : State()
        object UnloadingModel : State()
        object ModelReady : State()

        object Benchmarking : State()
        object ProcessingSystemPrompt : State()
        object ProcessingUserPrompt : State()

        object Generating : State()

        data class Error(val exception: Exception) : State()
    }

    companion object {
        /** Default cap on generated tokens per user prompt. */
        const val DEFAULT_PREDICT_LENGTH = 1024
    }
}
73
/** True while the engine is busy with work that must not be interrupted. */
val State.isUninterruptible
    get() = when (this) {
        State.Initializing,
        State.LoadingModel,
        State.UnloadingModel,
        State.Benchmarking,
        State.ProcessingSystemPrompt,
        State.ProcessingUserPrompt -> true
        else -> false
    }
81
/** True whenever a model is resident in memory, whether idle or busy. */
val State.isModelLoaded: Boolean
    get() = when (this) {
        State.ModelReady,
        State.Benchmarking,
        State.ProcessingSystemPrompt,
        State.ProcessingUserPrompt,
        State.Generating -> true
        else -> false
    }
88
/** Thrown by [InferenceEngine.loadModel] when the model's architecture is not supported. */
class UnsupportedArchitectureException : Exception()
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
new file mode 100644
index 0000000..2f15eef
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt
@@ -0,0 +1,61 @@
1package com.arm.aichat.gguf
2
3import kotlin.collections.get
4
5
/**
 * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
 * The `label` matches what llama‑cli prints.
 */
enum class FileType(val code: Int, val label: String) {
    ALL_F32(0, "all F32"),
    MOSTLY_F16(1, "F16"),
    MOSTLY_Q4_0(2, "Q4_0"),
    MOSTLY_Q4_1(3, "Q4_1"),
    // codes 4-6 were removed from the upstream enum
    MOSTLY_Q8_0(7, "Q8_0"),
    MOSTLY_Q5_0(8, "Q5_0"),
    MOSTLY_Q5_1(9, "Q5_1"),

    // --- K-quants ---
    MOSTLY_Q2_K(10, "Q2_K - Medium"),
    MOSTLY_Q3_K_S(11, "Q3_K - Small"),
    MOSTLY_Q3_K_M(12, "Q3_K - Medium"),
    MOSTLY_Q3_K_L(13, "Q3_K - Large"),
    MOSTLY_Q4_K_S(14, "Q4_K - Small"),
    MOSTLY_Q4_K_M(15, "Q4_K - Medium"),
    MOSTLY_Q5_K_S(16, "Q5_K - Small"),
    MOSTLY_Q5_K_M(17, "Q5_K - Medium"),
    MOSTLY_Q6_K(18, "Q6_K"),

    // --- IQ quants ---
    MOSTLY_IQ2_XXS(19, "IQ2_XXS - 2.06 bpw"),
    MOSTLY_IQ2_XS(20, "IQ2_XS - 2.31 bpw"),
    MOSTLY_Q2_K_S(21, "Q2_K - Small"),
    MOSTLY_IQ3_XS(22, "IQ3_XS - 3.30 bpw"),
    MOSTLY_IQ3_XXS(23, "IQ3_XXS - 3.06 bpw"),
    MOSTLY_IQ1_S(24, "IQ1_S - 1.56 bpw"),
    MOSTLY_IQ4_NL(25, "IQ4_NL - 4.5 bpw"),
    MOSTLY_IQ3_S(26, "IQ3_S - 3.44 bpw"),
    MOSTLY_IQ3_M(27, "IQ3_M - 3.66 bpw"),
    MOSTLY_IQ2_S(28, "IQ2_S - 2.50 bpw"),
    MOSTLY_IQ2_M(29, "IQ2_M - 2.70 bpw"),
    MOSTLY_IQ4_XS(30, "IQ4_XS - 4.25 bpw"),
    MOSTLY_IQ1_M(31, "IQ1_M - 1.75 bpw"),

    // --- BF16 & ternary ---
    MOSTLY_BF16(32, "BF16"),
    MOSTLY_TQ1_0(36, "TQ1_0 - 1.69 bpw ternary"),
    MOSTLY_TQ2_0(37, "TQ2_0 - 2.06 bpw ternary"),

    // Special flag used when llama.cpp has to guess the type.
    GUESSED(1024, "(guessed)"),

    UNKNOWN(-1, "unknown");

    companion object {
        private val byCode = entries.associateBy { it.code }

        /** Maps a raw `general.file_type` value to its enum constant, or [UNKNOWN] when absent/unrecognised. */
        fun fromCode(code: Int?): FileType = code?.let { byCode[it] } ?: UNKNOWN
    }
}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
new file mode 100644
index 0000000..5e1971a
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt
@@ -0,0 +1,132 @@
1package com.arm.aichat.gguf
2
3import java.io.IOException
4
5
/**
 * Structured metadata of a GGUF model file.
 *
 * Every nested group is optional except [basic]: a GGUF file is only required to
 * carry the header fields ([version], [tensorCount], [kvCount]).
 */
data class GgufMetadata(
    // Header fields
    val version: GgufVersion,
    val tensorCount: Long,
    val kvCount: Long,

    // General information
    val basic: BasicInfo,
    val author: AuthorInfo? = null,
    val additional: AdditionalInfo? = null,
    val architecture: ArchitectureInfo? = null,
    val baseModels: List<BaseModelInfo>? = null,
    val tokenizer: TokenizerInfo? = null,

    // Architecture-derived information
    val dimensions: DimensionsInfo? = null,
    val attention: AttentionInfo? = null,
    val rope: RopeInfo? = null,
    val experts: ExpertsInfo? = null
) {
    /** Known revisions of the GGUF container format. */
    enum class GgufVersion(val code: Int, val label: String) {
        /** First public draft; little‑endian only, no alignment key. */
        LEGACY_V1(1, "Legacy v1"),

        /** Added split‑file support and some extra metadata keys. */
        EXTENDED_V2(2, "Extended v2"),

        /** Current spec: endian‑aware, mandatory alignment, fully validated. */
        VALIDATED_V3(3, "Validated v3");

        companion object {
            /** @throws IOException when [code] does not match any known GGUF version. */
            fun fromCode(code: Int): GgufVersion {
                return entries.find { version -> version.code == code }
                    ?: throw IOException("Unknown GGUF version code $code")
            }
        }

        override fun toString(): String = "$label (code=$code)"
    }

    /** Identity of the model itself. */
    data class BasicInfo(
        val uuid: String? = null,
        val name: String? = null,
        val nameLabel: String? = null,
        val sizeLabel: String? = null, // Size label like "7B"
    )

    /** Provenance and licensing details. */
    data class AuthorInfo(
        val organization: String? = null,
        val author: String? = null,
        val doi: String? = null,
        val url: String? = null,
        val repoUrl: String? = null,
        val license: String? = null,
        val licenseLink: String? = null,
    )

    /** Free-form descriptive fields. */
    data class AdditionalInfo(
        val type: String? = null,
        val description: String? = null,
        val tags: List<String>? = null,
        val languages: List<String>? = null,
    )

    /** Architecture family, quantisation, and vocabulary details. */
    data class ArchitectureInfo(
        val architecture: String? = null,
        val fileType: Int? = null,
        val vocabSize: Int? = null,
        val finetune: String? = null,
        val quantizationVersion: Int? = null,
    )

    /** One entry of the `general.base_model.*` array. */
    data class BaseModelInfo(
        val name: String? = null,
        val author: String? = null,
        val version: String? = null,
        val organization: String? = null,
        val url: String? = null,
        val doi: String? = null,
        val uuid: String? = null,
        val repoUrl: String? = null,
    )

    /** Tokenizer model, special-token ids, and chat template. */
    data class TokenizerInfo(
        val model: String? = null,
        val bosTokenId: Int? = null,
        val eosTokenId: Int? = null,
        val unknownTokenId: Int? = null,
        val paddingTokenId: Int? = null,
        val addBosToken: Boolean? = null,
        val addEosToken: Boolean? = null,
        val chatTemplate: String? = null,
    )

    /** Core network dimensions. */
    data class DimensionsInfo(
        val contextLength: Int? = null,
        val embeddingSize: Int? = null,
        val blockCount: Int? = null,
        val feedForwardSize: Int? = null,
    )

    /** Attention-layer hyperparameters. */
    data class AttentionInfo(
        val headCount: Int? = null,
        val headCountKv: Int? = null,
        val keyLength: Int? = null,
        val valueLength: Int? = null,
        val layerNormEpsilon: Float? = null,
        val layerNormRmsEpsilon: Float? = null,
    )

    /** Rotary position-embedding (RoPE) parameters. */
    data class RopeInfo(
        val frequencyBase: Float? = null,
        val dimensionCount: Int? = null,
        val scalingType: String? = null,
        val scalingFactor: Float? = null,
        val attnFactor: Float? = null,
        val originalContextLength: Int? = null,
        val finetuned: Boolean? = null,
    )

    /** Mixture-of-experts configuration. */
    data class ExpertsInfo(
        val count: Int? = null,
        val usedCount: Int? = null,
    )
}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
new file mode 100644
index 0000000..264a6c0
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt
@@ -0,0 +1,77 @@
1package com.arm.aichat.gguf
2
3import android.content.Context
4import android.net.Uri
5import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
6import java.io.File
7import java.io.IOException
8import java.io.InputStream
9
/**
 * Interface for reading GGUF metadata from model files.
 * Use `GgufMetadataReader.create()` to get an instance.
 */
interface GgufMetadataReader {
    /**
     * Reads the magic number from the specified file path.
     *
     * @param file Java File to the GGUF file with absolute path
     * @return true if the file is a valid GGUF file
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun ensureSourceFileFormat(file: File): Boolean

    /**
     * Reads the magic number from the specified file path.
     *
     * @param context Context for obtaining [android.content.ContentProvider]
     * @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
     * @return true if the file is a valid GGUF file
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean

    /**
     * Reads and parses GGUF metadata from the specified file path.
     *
     * @param input the [InputStream] obtained from a readable file or content
     * @return Structured metadata extracted from the file
     * @throws IOException if file is damaged or cannot be read
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun readStructuredMetadata(input: InputStream): GgufMetadata

    companion object {
        // Bulky vocabulary/template keys whose values are skipped by default.
        private val DEFAULT_SKIP_KEYS = setOf(
            "tokenizer.chat_template",
            "tokenizer.ggml.scores",
            "tokenizer.ggml.tokens",
            "tokenizer.ggml.token_type"
        )

        /**
         * Creates a GgufMetadataReader.
         *
         * Note: default arguments replace the former zero-arg overload — existing
         * `create()` call sites still compile against this single factory.
         *
         * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
         * @param arraySummariseThreshold If ≥0, arrays longer get summarised, not materialised;
         *                                If -1, never summarise.
         */
        fun create(
            skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
            arraySummariseThreshold: Int = 1_000
        ): GgufMetadataReader = GgufMetadataReaderImpl(
            skipKeys = skipKeys,
            arraySummariseThreshold = arraySummariseThreshold
        )
    }
}
76
/** Signals that a source file is not in valid GGUF format (e.g. wrong magic number). */
class InvalidFileFormatException : IOException()
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
new file mode 100644
index 0000000..a293796
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt
@@ -0,0 +1,324 @@
1package com.arm.aichat.internal
2
3import android.content.Context
4import android.util.Log
5import com.arm.aichat.InferenceEngine
6import com.arm.aichat.UnsupportedArchitectureException
7import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
8import dalvik.annotation.optimization.FastNative
9import kotlinx.coroutines.CancellationException
10import kotlinx.coroutines.CoroutineScope
11import kotlinx.coroutines.Dispatchers
12import kotlinx.coroutines.ExperimentalCoroutinesApi
13import kotlinx.coroutines.SupervisorJob
14import kotlinx.coroutines.cancel
15import kotlinx.coroutines.flow.Flow
16import kotlinx.coroutines.flow.MutableStateFlow
17import kotlinx.coroutines.flow.StateFlow
18import kotlinx.coroutines.flow.asStateFlow
19import kotlinx.coroutines.flow.flow
20import kotlinx.coroutines.flow.flowOn
21import kotlinx.coroutines.launch
22import kotlinx.coroutines.runBlocking
23import kotlinx.coroutines.withContext
24import java.io.File
25import java.io.IOException
26
/**
 * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
 *
 * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
 * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
 * with the underlying C++ native code.
 *
 * The typical usage flow is:
 * 1. Get instance via [getInstance]
 * 2. Load a model with [loadModel]
 * 3. Send prompts with [sendUserPrompt]
 * 4. Generate responses as token streams
 * 5. Perform [cleanUp] when done with a model
 * 6. Properly [destroy] when completely done
 *
 * State transitions are managed automatically and validated at each operation.
 *
 * @see ai_chat.cpp for the native implementation details
 */
internal class InferenceEngineImpl private constructor(
    private val nativeLibDir: String
) : InferenceEngine {

    companion object {
        private val TAG = InferenceEngineImpl::class.java.simpleName

        @Volatile
        private var instance: InferenceEngine? = null

        /**
         * Create or obtain [InferenceEngineImpl]'s single instance.
         *
         * @param context for obtaining native library directory
         * @throws IllegalArgumentException if native library path is invalid
         * @throws UnsatisfiedLinkError if library failed to load
         */
        internal fun getInstance(context: Context) =
            instance ?: synchronized(this) {
                // Double-checked locking: re-read `instance` under the lock so a thread
                // that lost the race reuses the winner's engine instead of creating a
                // second one (the original code skipped this re-check).
                instance ?: run {
                    val nativeLibDir = context.applicationInfo.nativeLibraryDir
                    require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }

                    try {
                        Log.i(TAG, "Instantiating InferenceEngineImpl...")
                        InferenceEngineImpl(nativeLibDir).also { instance = it }
                    } catch (e: UnsatisfiedLinkError) {
                        Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
                        throw e
                    }
                }
            }
    }

    /**
     * JNI methods
     * @see ai_chat.cpp
     */
    @FastNative
    private external fun init(nativeLibDir: String)

    @FastNative
    private external fun load(modelPath: String): Int

    @FastNative
    private external fun prepare(): Int

    @FastNative
    private external fun systemInfo(): String

    @FastNative
    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String

    @FastNative
    private external fun processSystemPrompt(systemPrompt: String): Int

    @FastNative
    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int

    @FastNative
    private external fun generateNextToken(): String?

    @FastNative
    private external fun unload()

    @FastNative
    private external fun shutdown()

    private val _state =
        MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
    override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()

    // System prompt is only accepted in the window right after a successful model load.
    private var _readyForSystemPrompt = false
    @Volatile
    private var _cancelGeneration = false

    /**
     * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
     */
    @OptIn(ExperimentalCoroutinesApi::class)
    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())

    init {
        llamaScope.launch {
            try {
                check(_state.value is InferenceEngine.State.Uninitialized) {
                    "Cannot load native library in ${_state.value.javaClass.simpleName}!"
                }
                _state.value = InferenceEngine.State.Initializing
                Log.i(TAG, "Loading native library...")
                System.loadLibrary("ai-chat")
                init(nativeLibDir)
                _state.value = InferenceEngine.State.Initialized
                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")

            } catch (e: Exception) {
                Log.e(TAG, "Failed to load native library", e)
                // Publish the failure so observers don't see the engine stuck in
                // Initializing forever (consistent with the other error paths).
                _state.value = InferenceEngine.State.Error(e)
                throw e
            }
        }
    }

    /**
     * Load the LLM
     */
    override suspend fun loadModel(pathToModel: String) =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.Initialized) {
                "Cannot load model in ${_state.value.javaClass.simpleName}!"
            }

            try {
                Log.i(TAG, "Checking access to model file... \n$pathToModel")
                File(pathToModel).let {
                    require(it.exists()) { "File not found" }
                    require(it.isFile) { "Not a valid file" }
                    require(it.canRead()) { "Cannot read file" }
                }

                Log.i(TAG, "Loading model... \n$pathToModel")
                _readyForSystemPrompt = false
                _state.value = InferenceEngine.State.LoadingModel
                load(pathToModel).let {
                    // TODO-han.yin: find a better way to pass other error codes
                    if (it != 0) throw UnsupportedArchitectureException()
                }
                prepare().let {
                    if (it != 0) throw IOException("Failed to prepare resources")
                }
                Log.i(TAG, "Model loaded!")
                _readyForSystemPrompt = true

                _cancelGeneration = false
                _state.value = InferenceEngine.State.ModelReady
            } catch (e: Exception) {
                Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
                _state.value = InferenceEngine.State.Error(e)
                throw e
            }
        }

    /**
     * Process the plain text system prompt
     *
     * TODO-han.yin: return error code if system prompt not correct processed?
     */
    override suspend fun setSystemPrompt(systemPrompt: String) =
        withContext(llamaDispatcher) {
            require(systemPrompt.isNotBlank()) { "Cannot process empty system prompt!" }
            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
            check(_state.value is InferenceEngine.State.ModelReady) {
                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
            }

            Log.i(TAG, "Sending system prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingSystemPrompt
            processSystemPrompt(systemPrompt).let { result ->
                if (result != 0) {
                    RuntimeException("Failed to process system prompt: $result").also {
                        _state.value = InferenceEngine.State.Error(it)
                        throw it
                    }
                }
            }
            Log.i(TAG, "System prompt processed! Awaiting user prompt...")
            _state.value = InferenceEngine.State.ModelReady
        }

    /**
     * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
    override fun sendUserPrompt(
        message: String,
        predictLength: Int,
    ): Flow<String> = flow {
        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
        check(_state.value is InferenceEngine.State.ModelReady) {
            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
        }

        try {
            Log.i(TAG, "Sending user prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingUserPrompt

            processUserPrompt(message, predictLength).let { result ->
                if (result != 0) {
                    Log.e(TAG, "Failed to process user prompt: $result")
                    return@flow
                }
            }

            Log.i(TAG, "User prompt processed. Generating assistant prompt...")
            _state.value = InferenceEngine.State.Generating
            // Pull tokens one at a time; null from JNI marks end-of-generation.
            while (!_cancelGeneration) {
                generateNextToken()?.let { utf8token ->
                    if (utf8token.isNotEmpty()) emit(utf8token)
                } ?: break
            }
            if (_cancelGeneration) {
                Log.i(TAG, "Assistant generation aborted per requested.")
            } else {
                Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
            }
            _state.value = InferenceEngine.State.ModelReady
        } catch (e: CancellationException) {
            // Collector went away: restore a usable state, then rethrow as required
            // by cooperative cancellation.
            Log.i(TAG, "Assistant generation's flow collection cancelled.")
            _state.value = InferenceEngine.State.ModelReady
            throw e
        } catch (e: Exception) {
            Log.e(TAG, "Error during generation!", e)
            _state.value = InferenceEngine.State.Error(e)
            throw e
        }
    }.flowOn(llamaDispatcher)

    /**
     * Benchmark the model
     */
    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.ModelReady) {
                "Benchmark request discarded due to: $state"
            }
            Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
            _readyForSystemPrompt = false // Just to be safe
            _state.value = InferenceEngine.State.Benchmarking
            benchModel(pp, tg, pl, nr).also {
                _state.value = InferenceEngine.State.ModelReady
            }
        }

    /**
     * Unloads the model and frees resources, or reset error states
     */
    override fun cleanUp() {
        _cancelGeneration = true
        runBlocking(llamaDispatcher) {
            when (val state = _state.value) {
                is InferenceEngine.State.ModelReady -> {
                    Log.i(TAG, "Unloading model and free resources...")
                    _readyForSystemPrompt = false
                    _state.value = InferenceEngine.State.UnloadingModel

                    unload()

                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "Model unloaded!")
                    Unit
                }

                is InferenceEngine.State.Error -> {
                    Log.i(TAG, "Resetting error states...")
                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "States reset!")
                    Unit
                }

                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
            }
        }
    }

    /**
     * Cancel all ongoing coroutines and free GGML backends
     */
    override fun destroy() {
        _cancelGeneration = true
        runBlocking(llamaDispatcher) {
            _readyForSystemPrompt = false
            when (_state.value) {
                is InferenceEngine.State.Uninitialized -> {}
                is InferenceEngine.State.Initialized -> shutdown()
                else -> { unload(); shutdown() }
            }
        }
        llamaScope.cancel()
    }
}
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
new file mode 100644
index 0000000..bf250ac
--- /dev/null
+++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt
@@ -0,0 +1,590 @@
1package com.arm.aichat.internal.gguf
2
3import android.content.Context
4import android.net.Uri
5import com.arm.aichat.gguf.GgufMetadata
6import com.arm.aichat.gguf.GgufMetadataReader
7import com.arm.aichat.gguf.InvalidFileFormatException
8import java.io.File
9import java.io.IOException
10import java.io.InputStream
11
12
/**
 * Utility class to read GGUF model files and extract metadata key-value pairs.
 * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
 *
 * Streaming-only: tensor payloads are never mapped or loaded into memory, so even
 * multi-GB model files can be inspected cheaply.
 *
 * @param skipKeys Metadata keys whose values are skipped over instead of parsed.
 * @param arraySummariseThreshold Arrays longer than this are skipped and replaced by a
 *        one-line summary string; a negative value disables summarising.
 */
internal class GgufMetadataReaderImpl(
    private val skipKeys: Set<String>,
    private val arraySummariseThreshold: Int,
) : GgufMetadataReader {
    companion object {
        private const val ARCH_LLAMA = "llama"
    }

    /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */
    enum class MetadataType(val code: Int) {
        UINT8(0), INT8(1), UINT16(2), INT16(3),
        UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
        STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);

        companion object {
            private val codeMap = entries.associateBy(MetadataType::code)

            /** @throws IOException for type codes not defined by the GGUF spec. */
            fun fromCode(code: Int): MetadataType = codeMap[code]
                ?: throw IOException("Unknown metadata value type code: $code")
        }
    }

    /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */
    sealed class MetadataValue {
        data class UInt8(val value: UByte) : MetadataValue()     // 0: 8-bit unsigned int
        data class Int8(val value: Byte) : MetadataValue()       // 1: 8-bit signed int
        data class UInt16(val value: UShort) : MetadataValue()   // 2: 16-bit unsigned int (little-endian)
        data class Int16(val value: Short) : MetadataValue()     // 3: 16-bit signed int (little-endian)
        data class UInt32(val value: UInt) : MetadataValue()     // 4: 32-bit unsigned int (little-endian)
        data class Int32(val value: Int) : MetadataValue()       // 5: 32-bit signed int (little-endian)
        data class Float32(val value: Float) : MetadataValue()   // 6: 32-bit IEEE754 float
        data class Bool(val value: Boolean) : MetadataValue()    // 7: Boolean (1-byte, 0=false, 1=true)
        data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
        data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
        data class UInt64(val value: ULong) : MetadataValue()    // 10: 64-bit unsigned int (little-endian)
        data class Int64(val value: Long) : MetadataValue()      // 11: 64-bit signed int (little-endian)
        data class Float64(val value: Double) : MetadataValue()  // 12: 64-bit IEEE754 double
    }

    /* Convert MetadataValue to plain Kotlin primitives (arrays become lists, recursively). */
    private fun MetadataValue.toPrimitive(): Any = when (this) {
        is MetadataValue.UInt8 -> value
        is MetadataValue.Int8 -> value
        is MetadataValue.UInt16 -> value
        is MetadataValue.Int16 -> value
        is MetadataValue.UInt32 -> value
        is MetadataValue.Int32 -> value
        is MetadataValue.Float32 -> value
        is MetadataValue.Bool -> value
        is MetadataValue.StringVal -> value
        is MetadataValue.UInt64 -> value
        is MetadataValue.Int64 -> value
        is MetadataValue.Float64 -> value
        is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
    }

    /**
     * Checks that [file] starts with the GGUF magic number.
     *
     * @param file Local GGUF file candidate.
     * @return true if the file starts with the "GGUF" magic, otherwise false.
     * @throws IOException if fewer than four bytes can be read.
     */
    override suspend fun ensureSourceFileFormat(file: File): Boolean =
        file.inputStream().buffered().use { ensureMagic(it) }

    /**
     * Checks that the content behind [uri] starts with the GGUF magic number.
     *
     * @param context Context for obtaining ContentResolver.
     * @param uri Uri to the GGUF file provided by a ContentProvider.
     * @return true if the stream starts with the "GGUF" magic; false otherwise,
     *         or when the Uri cannot be opened.
     */
    override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
        context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true

    /** Reads the 4-byte magic; returns whether it equals "GGUF". Throws on premature EOF. */
    private fun ensureMagic(input: InputStream): Boolean {
        val magic = ByteArray(4)
        try {
            // Loop-based read: a single InputStream.read() may legally return < 4 bytes.
            input.readFully(magic)
        } catch (e: IOException) {
            throw IOException("Not a valid file!")
        }
        return magic.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF"
    }

    /**
     * High-level entry point: parses a `.gguf` stream and returns the fully
     * populated [GgufMetadata] tree.
     *
     * Steps performed internally:
     * 1. Reads and validates the 8-byte header (`"GGUF"` magic + version).
     * 2. Streams through the key-value section, skipping values whose key
     *    appears in [skipKeys] and summarising arrays longer than [arraySummariseThreshold].
     * 3. Converts the resulting raw map into strongly-typed sub-structures
     *    (basic info, tokenizer, rope, etc.).
     *
     * Tensor data is never read; only the header and metadata section are consumed.
     *
     * @param input Open stream positioned at the start of a `.gguf` file.
     * @return A [GgufMetadata] instance containing all recognised metadata.
     * @throws IOException if the stream is not GGUF, the version is unsupported,
     *         or the metadata block is truncated / corrupt.
     */
    override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
        // ── 1. header ──────────────────────────────────────────────────────────
        val version = ensureMagicAndVersion(input) // throws on mismatch
        val tensorCount = readLittleLong(input)
        val kvCount = readLittleLong(input)

        // ── 2. metadata map ───────────────────────────────────────────────────
        val meta = readMetaMap(input, kvCount)

        // ── 3. build structured object ────────────────────────────────────────
        return buildStructured(meta, version, tensorCount, kvCount)
    }

    /** Reads the 4-byte magic + 4-byte version; throws if magic ≠ "GGUF" or version unknown. */
    private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
        if (!ensureMagic(input)) throw InvalidFileFormatException()
        return GgufMetadata.GgufVersion.fromCode(readLE32(input))
    }

    /**
     * Reads the entire key-value section from the current stream position.
     *
     * Honours [skipKeys] by invoking [skipValue] instead of [parseValue] for
     * keys that should not be materialised.
     *
     * @param input Open stream positioned JUST AFTER the header.
     * @param kvCnt Number of key-value pairs (taken from the header).
     * @return Map with one [MetadataValue] for every key that is NOT skipped.
     */
    private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> =
        mutableMapOf<String, MetadataValue>().apply {
            var remaining = kvCnt // keep the full 64-bit count; do not truncate to Int
            while (remaining-- > 0) {
                val key = readString(input)
                val valueType = MetadataType.fromCode(readLE32(input))
                if (key in skipKeys) {
                    skipValue(input, valueType)
                } else {
                    this[key] = parseValue(input, valueType)
                }
            }
        }

    /**
     * Converts a flat [Map]<[String], [MetadataValue]> into the strongly-typed
     * [GgufMetadata] tree used by the rest of the app.
     *
     * Only the keys listed in the GGUF spec are copied into dedicated data classes.
     * Sections where every field is null collapse to null via [takeUnless].
     *
     * @param m Raw key/value map.
     * @param version GGUF file-format version (enum).
     * @param tensorCnt Number of tensors (from the header).
     * @param kvCnt Total metadata pair count (from the header).
     */
    private fun buildStructured(
        m: Map<String, MetadataValue>,
        version: GgufMetadata.GgufVersion,
        tensorCnt: Long,
        kvCnt: Long
    ): GgufMetadata {
        // ---------- typed accessors over the raw map ----------
        fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
        fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
        fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
        fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
        fun String.strList(): List<String>? =
            (m[this] as? MetadataValue.ArrayVal)
                ?.elements
                ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }

        // Architecture name prefixes most model-specific keys; default to llama.
        val arch = "general.architecture".str() ?: ARCH_LLAMA

        // -------------- populate sections ----------------
        val basic = GgufMetadata.BasicInfo(
            uuid = "general.uuid".str(),
            name = "general.basename".str(),
            nameLabel = "general.name".str(),
            sizeLabel = "general.size_label".str()
        )

        val author = GgufMetadata.AuthorInfo(
            organization = "general.organization".str(),
            author = "general.author".str(),
            doi = "general.doi".str(),
            url = "general.url".str(),
            repoUrl = "general.repo_url".str(),
            license = "general.license".str(),
            licenseLink = "general.license.link".str()
        ).takeUnless {
            organization == null && author == null && doi == null &&
            url == null && repoUrl == null && license == null && licenseLink == null
        }

        val additional = GgufMetadata.AdditionalInfo(
            type = "general.type".str(),
            description = "general.description".str(),
            tags = "general.tags".strList(),
            languages = "general.languages".strList()
        ).takeUnless {
            type == null && description == null && tags == null && languages == null
        }

        val architectureInfo = GgufMetadata.ArchitectureInfo(
            architecture = arch,
            fileType = "general.file_type".u32(),
            vocabSize = "$arch.vocab_size".u32(),
            finetune = "general.finetune".str(),
            quantizationVersion = "general.quantization_version".u32()
        ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }

        val baseModels = buildList {
            val n = "general.base_model.count".u32() ?: 0
            for (i in 0 until n) {
                fun k(s: String) = "general.base_model.$i.$s"
                add(
                    GgufMetadata.BaseModelInfo(
                        name = k("name").str(),
                        author = k("author").str(),
                        version = k("version").str(),
                        organization = k("organization").str(),
                        url = k("url").str(),
                        doi = k("doi").str(),
                        uuid = k("uuid").str(),
                        repoUrl = k("repo_url").str(),
                    )
                )
            }
        }.takeIf { it.isNotEmpty() }

        val tokenizer = GgufMetadata.TokenizerInfo(
            model = "tokenizer.ggml.model".str(),
            bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
            eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
            unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
            paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
            addBosToken = "tokenizer.ggml.add_bos_token".bool(),
            addEosToken = "tokenizer.ggml.add_eos_token".bool(),
            chatTemplate = "tokenizer.chat_template".str()
        ).takeUnless { model == null && bosTokenId == null && eosTokenId == null &&
            unknownTokenId == null && paddingTokenId == null &&
            addBosToken == null && addEosToken == null && chatTemplate == null
        }

        val dimensions = GgufMetadata.DimensionsInfo(
            contextLength = "$arch.context_length".u32(),
            embeddingSize = "$arch.embedding_length".u32(),
            blockCount = "$arch.block_count".u32(),
            feedForwardSize = "$arch.feed_forward_length".u32()
        ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }

        val attention = GgufMetadata.AttentionInfo(
            headCount = "$arch.attention.head_count".u32(),
            headCountKv = "$arch.attention.head_count_kv".u32(),
            keyLength = "$arch.attention.key_length".u32(),
            valueLength = "$arch.attention.value_length".u32(),
            layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
            layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
        ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
            layerNormEpsilon == null && layerNormRmsEpsilon == null
        }

        val rope = GgufMetadata.RopeInfo(
            frequencyBase = "$arch.rope.freq_base".f32(),
            dimensionCount = "$arch.rope.dimension_count".u32(),
            scalingType = "$arch.rope.scaling.type".str(),
            scalingFactor = "$arch.rope.scaling.factor".f32(),
            attnFactor = "$arch.rope.scaling.attn_factor".f32(),
            originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
            finetuned = "$arch.rope.scaling.finetuned".bool()
        ).takeUnless { frequencyBase == null && dimensionCount == null &&
            scalingType == null && scalingFactor == null && attnFactor == null &&
            originalContextLength == null && finetuned == null
        }

        val experts = GgufMetadata.ExpertsInfo(
            count = "$arch.expert_count".u32(),
            usedCount = "$arch.expert_used_count".u32()
        ).takeUnless { count == null && usedCount == null }

        return GgufMetadata(
            version = version,
            tensorCount = tensorCnt,
            kvCount = kvCnt,
            basic = basic,
            author = author,
            additional = additional,
            architecture = architectureInfo,
            baseModels = baseModels,
            tokenizer = tokenizer,
            dimensions = dimensions,
            attention = attention,
            rope = rope,
            experts = experts
        )
    }

    /**
     * Recursively parses a metadata value of the given type from the input stream.
     *
     * All fixed-width reads go through [readLE16]/[readLE32]/[readLittleLong], which
     * loop until the requested byte count is available (a plain `read(buf)` may
     * legally return fewer bytes without being at EOF).
     *
     * @param input The input stream positioned at the start of the value.
     * @param type The metadata value type to parse.
     * @throws IOException on premature EOF or malformed data.
     */
    private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
        MetadataType.UINT8 -> MetadataValue.UInt8(readByteOrThrow(input, "uint8").toUByte())
        MetadataType.INT8 -> MetadataValue.Int8(readByteOrThrow(input, "int8").toByte())
        MetadataType.UINT16 -> MetadataValue.UInt16(readLE16(input).toUShort())
        MetadataType.INT16 -> MetadataValue.Int16(readLE16(input).toShort())
        MetadataType.UINT32 -> MetadataValue.UInt32(readLE32(input).toUInt())
        MetadataType.INT32 -> MetadataValue.Int32(readLE32(input))
        MetadataType.FLOAT32 -> MetadataValue.Float32(Float.fromBits(readLE32(input)))
        MetadataType.BOOL -> {
            val byteVal = readByteOrThrow(input, "boolean")
            if (byteVal != 0 && byteVal != 1) {
                throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
            }
            MetadataValue.Bool(byteVal != 0)
        }
        MetadataType.STRING -> MetadataValue.StringVal(readString(input))
        MetadataType.ARRAY -> {
            val elemType = MetadataType.fromCode(readLE32(input))
            val count = readLittleLong(input).toInt()
            if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
                // Fast-forward without allocating the elements.
                repeat(count) { skipValue(input, elemType) }
                MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
            } else {
                val list = ArrayList<MetadataValue>(count)
                repeat(count) { list += parseValue(input, elemType) }
                MetadataValue.ArrayVal(elemType, list)
            }
        }
        // toULong()/fromBits reinterpret the 64-bit pattern read by readLittleLong.
        MetadataType.UINT64 -> MetadataValue.UInt64(readLittleLong(input).toULong())
        MetadataType.INT64 -> MetadataValue.Int64(readLittleLong(input))
        MetadataType.FLOAT64 -> MetadataValue.Float64(Double.fromBits(readLittleLong(input)))
    }

    /** Null-propagating takeUnless with a receiver lambda (so fields read as `foo == null`). */
    private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
        this?.takeIf { !it.check() }

    /** Helper: Skip a value in the stream without storing it (still maintains pointer). */
    private fun skipValue(input: InputStream, type: MetadataType) {
        when (type) {
            MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
            MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
            MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
            MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
            MetadataType.STRING -> {
                val len = readLittleLong(input); input.skipFully(len)
            }
            MetadataType.ARRAY -> {
                val elemType = MetadataType.fromCode(readLE32(input))
                val len = readLittleLong(input)
                repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
            }
        }
    }

    /** Helper: Read a single byte or throw with a type-specific message on EOF. */
    private fun readByteOrThrow(input: InputStream, what: String): Int =
        input.read().also {
            if (it == -1) throw IOException("Unexpected EOF while reading $what value.")
        }

    /** Helper: Read 2 bytes and assemble an unsigned little-endian 16-bit value (0..65535). */
    private fun readLE16(input: InputStream): Int {
        val bytes = ByteArray(2)
        input.readFully(bytes)
        return (bytes[1].toInt() and 0xFF shl 8) or (bytes[0].toInt() and 0xFF)
    }

    /** Helper: Read 4 bytes and assemble a little-endian 32-bit value. */
    private fun readLE32(input: InputStream): Int {
        val bytes = ByteArray(4)
        input.readFully(bytes)
        return littleEndianBytesToInt(bytes)
    }

    /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */
    private fun readLittleLong(input: InputStream): Long {
        val bytes = ByteArray(8)
        input.readFully(bytes)

        // Combine 8 bytes into a 64-bit value (Little Endian).
        // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement).
        // In our context (lengths/counts), such extremely large values are not expected.
        return (bytes[7].toLong() and 0xFFL shl 56) or
                (bytes[6].toLong() and 0xFFL shl 48) or
                (bytes[5].toLong() and 0xFFL shl 40) or
                (bytes[4].toLong() and 0xFFL shl 32) or
                (bytes[3].toLong() and 0xFFL shl 24) or
                (bytes[2].toLong() and 0xFFL shl 16) or
                (bytes[1].toLong() and 0xFFL shl 8) or
                (bytes[0].toLong() and 0xFFL)
    }

    /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */
    private fun readString(input: InputStream): String =
        // Read 8-byte little-endian length (number of bytes in the string).
        readLittleLong(input).let { len ->
            if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")

            // Read the UTF-8 bytes of the given length.
            ByteArray(len.toInt()).let {
                if (it.isNotEmpty()) input.readFully(it)
                String(it, Charsets.UTF_8)
            }
        }

    /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */
    private fun littleEndianBytesToInt(bytes: ByteArray): Int =
        // Note: assumes bytes length is 4.
        (bytes[3].toInt() and 0xFF shl 24) or
        (bytes[2].toInt() and 0xFF shl 16) or
        (bytes[1].toInt() and 0xFF shl 8) or
        (bytes[0].toInt() and 0xFF)

    /**
     * Robust skip that works the same on JDK 11 and Android’s desugared runtime.
     *
     * @param n Number of bytes to advance in the stream.
     * @throws IOException on premature EOF.
     */
    private fun InputStream.skipFully(n: Long) {
        var remaining = n
        val scratch = ByteArray(8192) // read-and-toss buffer
        while (remaining > 0) {
            val skipped = skip(remaining)
            when {
                skipped > 0 -> remaining -= skipped // normal fast path
                skipped == 0L -> {
                    // fallback: read and discard (some streams always report 0 from skip())
                    val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
                    if (read == -1) throw IOException("EOF while skipping $n bytes")
                    remaining -= read
                }
                else -> throw IOException("Skip returned negative value")
            }
        }
    }

    /**
     * Extension that keeps reading until the requested number of bytes are filled.
     * A single `read()` may legally return fewer bytes than requested, so this
     * loops until [len] bytes are available or EOF is hit.
     *
     * @param buf Destination buffer.
     * @param len Number of bytes to fill (defaults to `buf.size`).
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
        var off = 0
        while (off < len) {
            val n = read(buf, off, len - off)
            if (n == -1) throw IOException("EOF after $off of $len bytes")
            off += n
        }
    }
}