diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat')
7 files changed, 1287 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt new file mode 100644 index 0000000..b72a24e --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt | |||
| @@ -0,0 +1,14 @@ | |||
| 1 | package com.arm.aichat | ||
| 2 | |||
| 3 | import android.content.Context | ||
| 4 | import com.arm.aichat.internal.InferenceEngineImpl | ||
| 5 | |||
| 6 | /** | ||
| 7 | * Main entry point for Arm's AI Chat library. | ||
| 8 | */ | ||
| 9 | object AiChat { | ||
| 10 | /** | ||
| 11 | * Get the inference engine singleton instance. | ||
| 12 | */ | ||
| 13 | fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context) | ||
| 14 | } | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt new file mode 100644 index 0000000..26c1668 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt | |||
| @@ -0,0 +1,89 @@ | |||
| 1 | package com.arm.aichat | ||
| 2 | |||
| 3 | import com.arm.aichat.InferenceEngine.State | ||
| 4 | import kotlinx.coroutines.flow.Flow | ||
| 5 | import kotlinx.coroutines.flow.StateFlow | ||
| 6 | |||
| 7 | /** | ||
| 8 | * Interface defining the core LLM inference operations. | ||
| 9 | */ | ||
| 10 | interface InferenceEngine { | ||
| 11 | /** | ||
| 12 | * Current state of the inference engine | ||
| 13 | */ | ||
| 14 | val state: StateFlow<State> | ||
| 15 | |||
| 16 | /** | ||
| 17 | * Load a model from the given path. | ||
| 18 | * | ||
| 19 | * @throws UnsupportedArchitectureException if model architecture not supported | ||
| 20 | */ | ||
| 21 | suspend fun loadModel(pathToModel: String) | ||
| 22 | |||
| 23 | /** | ||
| 24 | * Sends a system prompt to the loaded model | ||
| 25 | */ | ||
| 26 | suspend fun setSystemPrompt(systemPrompt: String) | ||
| 27 | |||
| 28 | /** | ||
| 29 | * Sends a user prompt to the loaded model and returns a Flow of generated tokens. | ||
| 30 | */ | ||
| 31 | fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> | ||
| 32 | |||
| 33 | /** | ||
| 34 | * Runs a benchmark with the specified parameters. | ||
| 35 | */ | ||
| 36 | suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String | ||
| 37 | |||
| 38 | /** | ||
| 39 | * Unloads the currently loaded model. | ||
| 40 | */ | ||
| 41 | fun cleanUp() | ||
| 42 | |||
| 43 | /** | ||
| 44 | * Cleans up resources when the engine is no longer needed. | ||
| 45 | */ | ||
| 46 | fun destroy() | ||
| 47 | |||
| 48 | /** | ||
| 49 | * States of the inference engine | ||
| 50 | */ | ||
| 51 | sealed class State { | ||
| 52 | object Uninitialized : State() | ||
| 53 | object Initializing : State() | ||
| 54 | object Initialized : State() | ||
| 55 | |||
| 56 | object LoadingModel : State() | ||
| 57 | object UnloadingModel : State() | ||
| 58 | object ModelReady : State() | ||
| 59 | |||
| 60 | object Benchmarking : State() | ||
| 61 | object ProcessingSystemPrompt : State() | ||
| 62 | object ProcessingUserPrompt : State() | ||
| 63 | |||
| 64 | object Generating : State() | ||
| 65 | |||
| 66 | data class Error(val exception: Exception) : State() | ||
| 67 | } | ||
| 68 | |||
| 69 | companion object { | ||
| 70 | const val DEFAULT_PREDICT_LENGTH = 1024 | ||
| 71 | } | ||
| 72 | } | ||
| 73 | |||
| 74 | val State.isUninterruptible | ||
| 75 | get() = this is State.Initializing || | ||
| 76 | this is State.LoadingModel || | ||
| 77 | this is State.UnloadingModel || | ||
| 78 | this is State.Benchmarking || | ||
| 79 | this is State.ProcessingSystemPrompt || | ||
| 80 | this is State.ProcessingUserPrompt | ||
| 81 | |||
| 82 | val State.isModelLoaded: Boolean | ||
| 83 | get() = this is State.ModelReady || | ||
| 84 | this is State.Benchmarking || | ||
| 85 | this is State.ProcessingSystemPrompt || | ||
| 86 | this is State.ProcessingUserPrompt || | ||
| 87 | this is State.Generating | ||
| 88 | |||
| 89 | class UnsupportedArchitectureException : Exception() | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt new file mode 100644 index 0000000..2f15eef --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt | |||
| @@ -0,0 +1,61 @@ | |||
| 1 | package com.arm.aichat.gguf | ||
| 2 | |||
| 3 | import kotlin.collections.get | ||
| 4 | |||
| 5 | |||
| 6 | /** | ||
| 7 | * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`). | ||
| 8 | * The `label` matches what llama‑cli prints. | ||
| 9 | */ | ||
| 10 | enum class FileType(val code: Int, val label: String) { | ||
| 11 | ALL_F32(0, "all F32"), | ||
| 12 | MOSTLY_F16(1, "F16"), | ||
| 13 | MOSTLY_Q4_0(2, "Q4_0"), | ||
| 14 | MOSTLY_Q4_1(3, "Q4_1"), | ||
| 15 | // 4 removed | ||
| 16 | MOSTLY_Q8_0(7, "Q8_0"), | ||
| 17 | MOSTLY_Q5_0(8, "Q5_0"), | ||
| 18 | MOSTLY_Q5_1(9, "Q5_1"), | ||
| 19 | |||
| 20 | /* K‑quants ------------------------------------------------------------ */ | ||
| 21 | MOSTLY_Q2_K (10, "Q2_K - Medium"), | ||
| 22 | MOSTLY_Q3_K_S (11, "Q3_K - Small"), | ||
| 23 | MOSTLY_Q3_K_M (12, "Q3_K - Medium"), | ||
| 24 | MOSTLY_Q3_K_L (13, "Q3_K - Large"), | ||
| 25 | MOSTLY_Q4_K_S (14, "Q4_K - Small"), | ||
| 26 | MOSTLY_Q4_K_M (15, "Q4_K - Medium"), | ||
| 27 | MOSTLY_Q5_K_S (16, "Q5_K - Small"), | ||
| 28 | MOSTLY_Q5_K_M (17, "Q5_K - Medium"), | ||
| 29 | MOSTLY_Q6_K (18, "Q6_K"), | ||
| 30 | |||
| 31 | /* IQ quants ----------------------------------------------------------- */ | ||
| 32 | MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"), | ||
| 33 | MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"), | ||
| 34 | MOSTLY_Q2_K_S (21, "Q2_K - Small"), | ||
| 35 | MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"), | ||
| 36 | MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"), | ||
| 37 | MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"), | ||
| 38 | MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"), | ||
| 39 | MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"), | ||
| 40 | MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"), | ||
| 41 | MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"), | ||
| 42 | MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"), | ||
| 43 | MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"), | ||
| 44 | MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"), | ||
| 45 | |||
| 46 | /* BF16 & Ternary ------------------------------------------------------ */ | ||
| 47 | MOSTLY_BF16 (32, "BF16"), | ||
| 48 | MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"), | ||
| 49 | MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"), | ||
| 50 | |||
| 51 | /* Special flag -------------------------------------------------------- */ | ||
| 52 | GUESSED(1024, "(guessed)"), | ||
| 53 | |||
| 54 | UNKNOWN(-1, "unknown"); | ||
| 55 | |||
| 56 | companion object { | ||
| 57 | private val map = entries.associateBy(FileType::code) | ||
| 58 | |||
| 59 | fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN | ||
| 60 | } | ||
| 61 | } | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt new file mode 100644 index 0000000..5e1971a --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt | |||
| @@ -0,0 +1,132 @@ | |||
| 1 | package com.arm.aichat.gguf | ||
| 2 | |||
| 3 | import java.io.IOException | ||
| 4 | |||
| 5 | |||
| 6 | /** | ||
| 7 | * Structured metadata of GGUF | ||
| 8 | */ | ||
| 9 | data class GgufMetadata( | ||
| 10 | // Basic file info | ||
| 11 | val version: GgufVersion, | ||
| 12 | val tensorCount: Long, | ||
| 13 | val kvCount: Long, | ||
| 14 | |||
| 15 | // General info | ||
| 16 | val basic: BasicInfo, | ||
| 17 | val author: AuthorInfo? = null, | ||
| 18 | val additional: AdditionalInfo? = null, | ||
| 19 | val architecture: ArchitectureInfo? = null, | ||
| 20 | val baseModels: List<BaseModelInfo>? = null, | ||
| 21 | val tokenizer: TokenizerInfo? = null, | ||
| 22 | |||
| 23 | // Derivative info | ||
| 24 | val dimensions: DimensionsInfo? = null, | ||
| 25 | val attention: AttentionInfo? = null, | ||
| 26 | val rope: RopeInfo? = null, | ||
| 27 | val experts: ExpertsInfo? = null | ||
| 28 | ) { | ||
| 29 | enum class GgufVersion(val code: Int, val label: String) { | ||
| 30 | /** First public draft; little‑endian only, no alignment key. */ | ||
| 31 | LEGACY_V1(1, "Legacy v1"), | ||
| 32 | |||
| 33 | /** Added split‑file support and some extra metadata keys. */ | ||
| 34 | EXTENDED_V2(2, "Extended v2"), | ||
| 35 | |||
| 36 | /** Current spec: endian‑aware, mandatory alignment, fully validated. */ | ||
| 37 | VALIDATED_V3(3, "Validated v3"); | ||
| 38 | |||
| 39 | companion object { | ||
| 40 | fun fromCode(code: Int): GgufVersion = | ||
| 41 | entries.firstOrNull { it.code == code } | ||
| 42 | ?: throw IOException("Unknown GGUF version code $code") | ||
| 43 | } | ||
| 44 | |||
| 45 | override fun toString(): String = "$label (code=$code)" | ||
| 46 | } | ||
| 47 | |||
| 48 | data class BasicInfo( | ||
| 49 | val uuid: String? = null, | ||
| 50 | val name: String? = null, | ||
| 51 | val nameLabel: String? = null, | ||
| 52 | val sizeLabel: String? = null, // Size label like "7B" | ||
| 53 | ) | ||
| 54 | |||
| 55 | data class AuthorInfo( | ||
| 56 | val organization: String? = null, | ||
| 57 | val author: String? = null, | ||
| 58 | val doi: String? = null, | ||
| 59 | val url: String? = null, | ||
| 60 | val repoUrl: String? = null, | ||
| 61 | val license: String? = null, | ||
| 62 | val licenseLink: String? = null, | ||
| 63 | ) | ||
| 64 | |||
| 65 | data class AdditionalInfo( | ||
| 66 | val type: String? = null, | ||
| 67 | val description: String? = null, | ||
| 68 | val tags: List<String>? = null, | ||
| 69 | val languages: List<String>? = null, | ||
| 70 | ) | ||
| 71 | |||
| 72 | data class ArchitectureInfo( | ||
| 73 | val architecture: String? = null, | ||
| 74 | val fileType: Int? = null, | ||
| 75 | val vocabSize: Int? = null, | ||
| 76 | val finetune: String? = null, | ||
| 77 | val quantizationVersion: Int? = null, | ||
| 78 | ) | ||
| 79 | |||
| 80 | data class BaseModelInfo( | ||
| 81 | val name: String? = null, | ||
| 82 | val author: String? = null, | ||
| 83 | val version: String? = null, | ||
| 84 | val organization: String? = null, | ||
| 85 | val url: String? = null, | ||
| 86 | val doi: String? = null, | ||
| 87 | val uuid: String? = null, | ||
| 88 | val repoUrl: String? = null, | ||
| 89 | ) | ||
| 90 | |||
| 91 | data class TokenizerInfo( | ||
| 92 | val model: String? = null, | ||
| 93 | val bosTokenId: Int? = null, | ||
| 94 | val eosTokenId: Int? = null, | ||
| 95 | val unknownTokenId: Int? = null, | ||
| 96 | val paddingTokenId: Int? = null, | ||
| 97 | val addBosToken: Boolean? = null, | ||
| 98 | val addEosToken: Boolean? = null, | ||
| 99 | val chatTemplate: String? = null, | ||
| 100 | ) | ||
| 101 | |||
| 102 | data class DimensionsInfo( | ||
| 103 | val contextLength: Int? = null, | ||
| 104 | val embeddingSize: Int? = null, | ||
| 105 | val blockCount: Int? = null, | ||
| 106 | val feedForwardSize: Int? = null, | ||
| 107 | ) | ||
| 108 | |||
| 109 | data class AttentionInfo( | ||
| 110 | val headCount: Int? = null, | ||
| 111 | val headCountKv: Int? = null, | ||
| 112 | val keyLength: Int? = null, | ||
| 113 | val valueLength: Int? = null, | ||
| 114 | val layerNormEpsilon: Float? = null, | ||
| 115 | val layerNormRmsEpsilon: Float? = null, | ||
| 116 | ) | ||
| 117 | |||
| 118 | data class RopeInfo( | ||
| 119 | val frequencyBase: Float? = null, | ||
| 120 | val dimensionCount: Int? = null, | ||
| 121 | val scalingType: String? = null, | ||
| 122 | val scalingFactor: Float? = null, | ||
| 123 | val attnFactor: Float? = null, | ||
| 124 | val originalContextLength: Int? = null, | ||
| 125 | val finetuned: Boolean? = null, | ||
| 126 | ) | ||
| 127 | |||
| 128 | data class ExpertsInfo( | ||
| 129 | val count: Int? = null, | ||
| 130 | val usedCount: Int? = null, | ||
| 131 | ) | ||
| 132 | } | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt new file mode 100644 index 0000000..264a6c0 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt | |||
| @@ -0,0 +1,77 @@ | |||
| 1 | package com.arm.aichat.gguf | ||
| 2 | |||
| 3 | import android.content.Context | ||
| 4 | import android.net.Uri | ||
| 5 | import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl | ||
| 6 | import java.io.File | ||
| 7 | import java.io.IOException | ||
| 8 | import java.io.InputStream | ||
| 9 | |||
| 10 | /** | ||
| 11 | * Interface for reading GGUF metadata from model files. | ||
| 12 | * Use `GgufMetadataReader.create()` to get an instance. | ||
| 13 | */ | ||
| 14 | interface GgufMetadataReader { | ||
| 15 | /** | ||
| 16 | * Reads the magic number from the specified file path. | ||
| 17 | * | ||
| 18 | * @param file Java File to the GGUF file with absolute path | ||
| 19 | * @return true if file is valid GGUF, otherwise false | ||
| 20 | * @throws InvalidFileFormatException if file format is invalid | ||
| 21 | */ | ||
| 22 | suspend fun ensureSourceFileFormat(file: File): Boolean | ||
| 23 | |||
| 24 | /** | ||
| 25 | * Reads the magic number from the content resolved from the specified Uri. | ||
| 26 | * | ||
| 27 | * @param context Context for obtaining [android.content.ContentProvider] | ||
| 28 | * @param uri Uri to the GGUF file provided by [android.content.ContentProvider] | ||
| 29 | * @return true if file is valid GGUF, otherwise false | ||
| 30 | * @throws InvalidFileFormatException if file format is invalid | ||
| 31 | */ | ||
| 32 | suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean | ||
| 33 | |||
| 34 | /** | ||
| 35 | * Reads and parses GGUF metadata from the specified file path. | ||
| 36 | * | ||
| 37 | * @param input the [InputStream] obtained from a readable file or content | ||
| 38 | * @return Structured metadata extracted from the file | ||
| 39 | * @throws IOException if file is damaged or cannot be read | ||
| 40 | * @throws InvalidFileFormatException if file format is invalid | ||
| 41 | */ | ||
| 42 | suspend fun readStructuredMetadata(input: InputStream): GgufMetadata | ||
| 43 | |||
| 44 | companion object { | ||
| 45 | private val DEFAULT_SKIP_KEYS = setOf( | ||
| 46 | "tokenizer.chat_template", | ||
| 47 | "tokenizer.ggml.scores", | ||
| 48 | "tokenizer.ggml.tokens", | ||
| 49 | "tokenizer.ggml.token_type" | ||
| 50 | ) | ||
| 51 | |||
| 52 | /** | ||
| 53 | * Creates a default GgufMetadataReader instance | ||
| 54 | */ | ||
| 55 | fun create(): GgufMetadataReader = GgufMetadataReaderImpl( | ||
| 56 | skipKeys = DEFAULT_SKIP_KEYS, | ||
| 57 | arraySummariseThreshold = 1_000 | ||
| 58 | ) | ||
| 59 | |||
| 60 | /** | ||
| 61 | * Creates a GgufMetadataReader with custom configuration | ||
| 62 | * | ||
| 63 | * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map) | ||
| 64 | * @param arraySummariseThreshold If ≥0, arrays longer than this get summarised, not materialised; | ||
| 65 | * If -1, never summarise. | ||
| 66 | */ | ||
| 67 | fun create( | ||
| 68 | skipKeys: Set<String> = DEFAULT_SKIP_KEYS, | ||
| 69 | arraySummariseThreshold: Int = 1_000 | ||
| 70 | ): GgufMetadataReader = GgufMetadataReaderImpl( | ||
| 71 | skipKeys = skipKeys, | ||
| 72 | arraySummariseThreshold = arraySummariseThreshold | ||
| 73 | ) | ||
| 74 | } | ||
| 75 | } | ||
| 76 | |||
| 77 | class InvalidFileFormatException : IOException() | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt new file mode 100644 index 0000000..a293796 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt | |||
| @@ -0,0 +1,324 @@ | |||
| 1 | package com.arm.aichat.internal | ||
| 2 | |||
| 3 | import android.content.Context | ||
| 4 | import android.util.Log | ||
| 5 | import com.arm.aichat.InferenceEngine | ||
| 6 | import com.arm.aichat.UnsupportedArchitectureException | ||
| 7 | import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance | ||
| 8 | import dalvik.annotation.optimization.FastNative | ||
| 9 | import kotlinx.coroutines.CancellationException | ||
| 10 | import kotlinx.coroutines.CoroutineScope | ||
| 11 | import kotlinx.coroutines.Dispatchers | ||
| 12 | import kotlinx.coroutines.ExperimentalCoroutinesApi | ||
| 13 | import kotlinx.coroutines.SupervisorJob | ||
| 14 | import kotlinx.coroutines.cancel | ||
| 15 | import kotlinx.coroutines.flow.Flow | ||
| 16 | import kotlinx.coroutines.flow.MutableStateFlow | ||
| 17 | import kotlinx.coroutines.flow.StateFlow | ||
| 18 | import kotlinx.coroutines.flow.asStateFlow | ||
| 19 | import kotlinx.coroutines.flow.flow | ||
| 20 | import kotlinx.coroutines.flow.flowOn | ||
| 21 | import kotlinx.coroutines.launch | ||
| 22 | import kotlinx.coroutines.runBlocking | ||
| 23 | import kotlinx.coroutines.withContext | ||
| 24 | import java.io.File | ||
| 25 | import java.io.IOException | ||
| 26 | |||
| 27 | /** | ||
| 28 | * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models. | ||
| 29 | * | ||
| 30 | * This class implements a singleton pattern for managing the lifecycle of a single LLM instance. | ||
| 31 | * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety | ||
| 32 | * with the underlying C++ native code. | ||
| 33 | * | ||
| 34 | * The typical usage flow is: | ||
| 35 | * 1. Get instance via [getInstance] | ||
| 36 | * 2. Load a model with [loadModel] | ||
| 37 | * 3. Send prompts with [sendUserPrompt] | ||
| 38 | * 4. Generate responses as token streams | ||
| 39 | * 5. Perform [cleanUp] when done with a model | ||
| 40 | * 6. Properly [destroy] when completely done | ||
| 41 | * | ||
| 42 | * State transitions are managed automatically and validated at each operation. | ||
| 43 | * | ||
| 44 | * @see ai_chat.cpp for the native implementation details | ||
| 45 | */ | ||
| 46 | internal class InferenceEngineImpl private constructor( | ||
| 47 | private val nativeLibDir: String | ||
| 48 | ) : InferenceEngine { | ||
| 49 | |||
| 50 | companion object { | ||
| 51 | private val TAG = InferenceEngineImpl::class.java.simpleName | ||
| 52 | |||
| 53 | @Volatile | ||
| 54 | private var instance: InferenceEngine? = null | ||
| 55 | |||
| 56 | /** | ||
| 57 | * Create or obtain [InferenceEngineImpl]'s single instance. | ||
| 58 | * | ||
| 59 | * @param context Context for obtaining native library directory | ||
| 60 | * @throws IllegalArgumentException if native library path is invalid | ||
| 61 | * @throws UnsatisfiedLinkError if library failed to load | ||
| 62 | */ | ||
| 63 | internal fun getInstance(context: Context) = | ||
| 64 | instance ?: synchronized(this) { | ||
| 65 | val nativeLibDir = context.applicationInfo.nativeLibraryDir | ||
| 66 | require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" } | ||
| 67 | |||
| 68 | try { | ||
| 69 | Log.i(TAG, "Instantiating InferenceEngineImpl...") | ||
| 70 | InferenceEngineImpl(nativeLibDir).also { instance = it } | ||
| 71 | } catch (e: UnsatisfiedLinkError) { | ||
| 72 | Log.e(TAG, "Failed to load native library from $nativeLibDir", e) | ||
| 73 | throw e | ||
| 74 | } | ||
| 75 | } | ||
| 76 | } | ||
| 77 | |||
| 78 | /** | ||
| 79 | * JNI methods | ||
| 80 | * @see ai_chat.cpp | ||
| 81 | */ | ||
| 82 | @FastNative | ||
| 83 | private external fun init(nativeLibDir: String) | ||
| 84 | |||
| 85 | @FastNative | ||
| 86 | private external fun load(modelPath: String): Int | ||
| 87 | |||
| 88 | @FastNative | ||
| 89 | private external fun prepare(): Int | ||
| 90 | |||
| 91 | @FastNative | ||
| 92 | private external fun systemInfo(): String | ||
| 93 | |||
| 94 | @FastNative | ||
| 95 | private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String | ||
| 96 | |||
| 97 | @FastNative | ||
| 98 | private external fun processSystemPrompt(systemPrompt: String): Int | ||
| 99 | |||
| 100 | @FastNative | ||
| 101 | private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int | ||
| 102 | |||
| 103 | @FastNative | ||
| 104 | private external fun generateNextToken(): String? | ||
| 105 | |||
| 106 | @FastNative | ||
| 107 | private external fun unload() | ||
| 108 | |||
| 109 | @FastNative | ||
| 110 | private external fun shutdown() | ||
| 111 | |||
| 112 | private val _state = | ||
| 113 | MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized) | ||
| 114 | override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow() | ||
| 115 | |||
| 116 | private var _readyForSystemPrompt = false | ||
| 117 | @Volatile | ||
| 118 | private var _cancelGeneration = false | ||
| 119 | |||
| 120 | /** | ||
| 121 | * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations | ||
| 122 | */ | ||
| 123 | @OptIn(ExperimentalCoroutinesApi::class) | ||
| 124 | private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1) | ||
| 125 | private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob()) | ||
| 126 | |||
| 127 | init { | ||
| 128 | llamaScope.launch { | ||
| 129 | try { | ||
| 130 | check(_state.value is InferenceEngine.State.Uninitialized) { | ||
| 131 | "Cannot load native library in ${_state.value.javaClass.simpleName}!" | ||
| 132 | } | ||
| 133 | _state.value = InferenceEngine.State.Initializing | ||
| 134 | Log.i(TAG, "Loading native library...") | ||
| 135 | System.loadLibrary("ai-chat") | ||
| 136 | init(nativeLibDir) | ||
| 137 | _state.value = InferenceEngine.State.Initialized | ||
| 138 | Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}") | ||
| 139 | |||
| 140 | } catch (e: Exception) { | ||
| 141 | Log.e(TAG, "Failed to load native library", e) | ||
| 142 | throw e | ||
| 143 | } | ||
| 144 | } | ||
| 145 | } | ||
| 146 | |||
| 147 | /** | ||
| 148 | * Load the LLM | ||
| 149 | */ | ||
| 150 | override suspend fun loadModel(pathToModel: String) = | ||
| 151 | withContext(llamaDispatcher) { | ||
| 152 | check(_state.value is InferenceEngine.State.Initialized) { | ||
| 153 | "Cannot load model in ${_state.value.javaClass.simpleName}!" | ||
| 154 | } | ||
| 155 | |||
| 156 | try { | ||
| 157 | Log.i(TAG, "Checking access to model file... \n$pathToModel") | ||
| 158 | File(pathToModel).let { | ||
| 159 | require(it.exists()) { "File not found" } | ||
| 160 | require(it.isFile) { "Not a valid file" } | ||
| 161 | require(it.canRead()) { "Cannot read file" } | ||
| 162 | } | ||
| 163 | |||
| 164 | Log.i(TAG, "Loading model... \n$pathToModel") | ||
| 165 | _readyForSystemPrompt = false | ||
| 166 | _state.value = InferenceEngine.State.LoadingModel | ||
| 167 | load(pathToModel).let { | ||
| 168 | // TODO-han.yin: find a better way to pass other error codes | ||
| 169 | if (it != 0) throw UnsupportedArchitectureException() | ||
| 170 | } | ||
| 171 | prepare().let { | ||
| 172 | if (it != 0) throw IOException("Failed to prepare resources") | ||
| 173 | } | ||
| 174 | Log.i(TAG, "Model loaded!") | ||
| 175 | _readyForSystemPrompt = true | ||
| 176 | |||
| 177 | _cancelGeneration = false | ||
| 178 | _state.value = InferenceEngine.State.ModelReady | ||
| 179 | } catch (e: Exception) { | ||
| 180 | Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e) | ||
| 181 | _state.value = InferenceEngine.State.Error(e) | ||
| 182 | throw e | ||
| 183 | } | ||
| 184 | } | ||
| 185 | |||
| 186 | /** | ||
| 187 | * Process the plain text system prompt | ||
| 188 | * | ||
| 189 | * TODO-han.yin: return error code if system prompt not correctly processed? | ||
| 190 | */ | ||
| 191 | override suspend fun setSystemPrompt(prompt: String) = | ||
| 192 | withContext(llamaDispatcher) { | ||
| 193 | require(prompt.isNotBlank()) { "Cannot process empty system prompt!" } | ||
| 194 | check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" } | ||
| 195 | check(_state.value is InferenceEngine.State.ModelReady) { | ||
| 196 | "Cannot process system prompt in ${_state.value.javaClass.simpleName}!" | ||
| 197 | } | ||
| 198 | |||
| 199 | Log.i(TAG, "Sending system prompt...") | ||
| 200 | _readyForSystemPrompt = false | ||
| 201 | _state.value = InferenceEngine.State.ProcessingSystemPrompt | ||
| 202 | processSystemPrompt(prompt).let { result -> | ||
| 203 | if (result != 0) { | ||
| 204 | RuntimeException("Failed to process system prompt: $result").also { | ||
| 205 | _state.value = InferenceEngine.State.Error(it) | ||
| 206 | throw it | ||
| 207 | } | ||
| 208 | } | ||
| 209 | } | ||
| 210 | Log.i(TAG, "System prompt processed! Awaiting user prompt...") | ||
| 211 | _state.value = InferenceEngine.State.ModelReady | ||
| 212 | } | ||
| 213 | |||
| 214 | /** | ||
| 215 | * Send plain text user prompt to LLM, which starts generating tokens in a [Flow] | ||
| 216 | */ | ||
| 217 | override fun sendUserPrompt( | ||
| 218 | message: String, | ||
| 219 | predictLength: Int, | ||
| 220 | ): Flow<String> = flow { | ||
| 221 | require(message.isNotEmpty()) { "User prompt discarded due to being empty!" } | ||
| 222 | check(_state.value is InferenceEngine.State.ModelReady) { | ||
| 223 | "User prompt discarded due to: ${_state.value.javaClass.simpleName}" | ||
| 224 | } | ||
| 225 | |||
| 226 | try { | ||
| 227 | Log.i(TAG, "Sending user prompt...") | ||
| 228 | _readyForSystemPrompt = false | ||
| 229 | _state.value = InferenceEngine.State.ProcessingUserPrompt | ||
| 230 | |||
| 231 | processUserPrompt(message, predictLength).let { result -> | ||
| 232 | if (result != 0) { | ||
| 233 | Log.e(TAG, "Failed to process user prompt: $result") | ||
| 234 | return@flow | ||
| 235 | } | ||
| 236 | } | ||
| 237 | |||
| 238 | Log.i(TAG, "User prompt processed. Generating assistant prompt...") | ||
| 239 | _state.value = InferenceEngine.State.Generating | ||
| 240 | while (!_cancelGeneration) { | ||
| 241 | generateNextToken()?.let { utf8token -> | ||
| 242 | if (utf8token.isNotEmpty()) emit(utf8token) | ||
| 243 | } ?: break | ||
| 244 | } | ||
| 245 | if (_cancelGeneration) { | ||
| 246 | Log.i(TAG, "Assistant generation aborted as requested.") | ||
| 247 | } else { | ||
| 248 | Log.i(TAG, "Assistant generation complete. Awaiting user prompt...") | ||
| 249 | } | ||
| 250 | _state.value = InferenceEngine.State.ModelReady | ||
| 251 | } catch (e: CancellationException) { | ||
| 252 | Log.i(TAG, "Assistant generation's flow collection cancelled.") | ||
| 253 | _state.value = InferenceEngine.State.ModelReady | ||
| 254 | throw e | ||
| 255 | } catch (e: Exception) { | ||
| 256 | Log.e(TAG, "Error during generation!", e) | ||
| 257 | _state.value = InferenceEngine.State.Error(e) | ||
| 258 | throw e | ||
| 259 | } | ||
| 260 | }.flowOn(llamaDispatcher) | ||
| 261 | |||
| 262 | /** | ||
| 263 | * Benchmark the model | ||
| 264 | */ | ||
| 265 | override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String = | ||
| 266 | withContext(llamaDispatcher) { | ||
| 267 | check(_state.value is InferenceEngine.State.ModelReady) { | ||
| 268 | "Benchmark request discarded due to: $state" | ||
| 269 | } | ||
| 270 | Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)") | ||
| 271 | _readyForSystemPrompt = false // Just to be safe | ||
| 272 | _state.value = InferenceEngine.State.Benchmarking | ||
| 273 | benchModel(pp, tg, pl, nr).also { | ||
| 274 | _state.value = InferenceEngine.State.ModelReady | ||
| 275 | } | ||
| 276 | } | ||
| 277 | |||
| 278 | /** | ||
| 279 | * Unloads the model and frees resources, or reset error states | ||
| 280 | */ | ||
| 281 | override fun cleanUp() { | ||
| 282 | _cancelGeneration = true | ||
| 283 | runBlocking(llamaDispatcher) { | ||
| 284 | when (val state = _state.value) { | ||
| 285 | is InferenceEngine.State.ModelReady -> { | ||
| 286 | Log.i(TAG, "Unloading model and free resources...") | ||
| 287 | _readyForSystemPrompt = false | ||
| 288 | _state.value = InferenceEngine.State.UnloadingModel | ||
| 289 | |||
| 290 | unload() | ||
| 291 | |||
| 292 | _state.value = InferenceEngine.State.Initialized | ||
| 293 | Log.i(TAG, "Model unloaded!") | ||
| 294 | Unit | ||
| 295 | } | ||
| 296 | |||
| 297 | is InferenceEngine.State.Error -> { | ||
| 298 | Log.i(TAG, "Resetting error states...") | ||
| 299 | _state.value = InferenceEngine.State.Initialized | ||
| 300 | Log.i(TAG, "States reset!") | ||
| 301 | Unit | ||
| 302 | } | ||
| 303 | |||
| 304 | else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}") | ||
| 305 | } | ||
| 306 | } | ||
| 307 | } | ||
| 308 | |||
| 309 | /** | ||
| 310 | * Cancel all ongoing coroutines and free GGML backends | ||
| 311 | */ | ||
| 312 | override fun destroy() { | ||
| 313 | _cancelGeneration = true | ||
| 314 | runBlocking(llamaDispatcher) { | ||
| 315 | _readyForSystemPrompt = false | ||
| 316 | when(_state.value) { | ||
| 317 | is InferenceEngine.State.Uninitialized -> {} | ||
| 318 | is InferenceEngine.State.Initialized -> shutdown() | ||
| 319 | else -> { unload(); shutdown() } | ||
| 320 | } | ||
| 321 | } | ||
| 322 | llamaScope.cancel() | ||
| 323 | } | ||
| 324 | } | ||
diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt new file mode 100644 index 0000000..bf250ac --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt | |||
| @@ -0,0 +1,590 @@ | |||
| 1 | package com.arm.aichat.internal.gguf | ||
| 2 | |||
| 3 | import android.content.Context | ||
| 4 | import android.net.Uri | ||
| 5 | import com.arm.aichat.gguf.GgufMetadata | ||
| 6 | import com.arm.aichat.gguf.GgufMetadataReader | ||
| 7 | import com.arm.aichat.gguf.InvalidFileFormatException | ||
| 8 | import java.io.File | ||
| 9 | import java.io.IOException | ||
| 10 | import java.io.InputStream | ||
| 11 | |||
| 12 | |||
| 13 | /** | ||
| 14 | * Utility class to read GGUF model files and extract metadata key-value pairs. | ||
| 15 | * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data. | ||
| 16 | */ | ||
/**
 * Utility class to read GGUF model files and extract metadata key-value pairs.
 *
 * The parser streams through the header and metadata section of a GGUF v3 file
 * (little-endian) and never maps or loads tensor data, so even multi-GB model
 * files can be inspected quickly.
 *
 * All fixed-size reads go through [readNBytesExact]/[readFully], which loop until
 * the requested byte count is filled: a single [InputStream.read] call may return
 * fewer bytes than requested without being at EOF (this is common for
 * ContentResolver-backed streams), so a single-shot `read(buf) != n` check would
 * spuriously fail mid-file.
 *
 * @param skipKeys metadata keys whose values are skipped (parsed past, not stored).
 * @param arraySummariseThreshold arrays with more elements than this are skipped
 *        and replaced by a short textual summary; a negative value disables
 *        summarising.
 */
internal class GgufMetadataReaderImpl(
    private val skipKeys: Set<String>,
    private val arraySummariseThreshold: Int,
) : GgufMetadataReader {

    /** Enum corresponding to GGUF metadata value types (also used for array element typing). */
    enum class MetadataType(val code: Int) {
        UINT8(0), INT8(1), UINT16(2), INT16(3),
        UINT32(4), INT32(5), FLOAT32(6), BOOL(7),
        STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12);

        companion object {
            private val codeMap = entries.associateBy(MetadataType::code)

            /**
             * Maps a raw on-disk type code to the enum.
             * @throws IOException for an unrecognised code (corrupt or newer-format file).
             */
            fun fromCode(code: Int): MetadataType = codeMap[code]
                ?: throw IOException("Unknown metadata value type code: $code")
        }
    }

    /** Sealed class hierarchy providing type-safe representations for each GGUF metadata type. */
    sealed class MetadataValue {
        data class UInt8(val value: UByte) : MetadataValue()      // 0: 8-bit unsigned int
        data class Int8(val value: Byte) : MetadataValue()        // 1: 8-bit signed int
        data class UInt16(val value: UShort) : MetadataValue()    // 2: 16-bit unsigned int (little-endian)
        data class Int16(val value: Short) : MetadataValue()      // 3: 16-bit signed int (little-endian)
        data class UInt32(val value: UInt) : MetadataValue()      // 4: 32-bit unsigned int (little-endian)
        data class Int32(val value: Int) : MetadataValue()        // 5: 32-bit signed int (little-endian)
        data class Float32(val value: Float) : MetadataValue()    // 6: 32-bit IEEE754 float
        data class Bool(val value: Boolean) : MetadataValue()     // 7: Boolean (1-byte, 0=false, 1=true)
        data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed)
        data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue()
        data class UInt64(val value: ULong) : MetadataValue()     // 10: 64-bit unsigned int (little-endian)
        data class Int64(val value: Long) : MetadataValue()       // 11: 64-bit signed int (little-endian)
        data class Float64(val value: Double) : MetadataValue()   // 12: 64-bit IEEE754 double
    }

    /* Convert MetadataValue to plain Kotlin primitives (arrays become List<Any>). */
    private fun MetadataValue.toPrimitive(): Any = when (this) {
        is MetadataValue.UInt8 -> value
        is MetadataValue.Int8 -> value
        is MetadataValue.UInt16 -> value
        is MetadataValue.Int16 -> value
        is MetadataValue.UInt32 -> value
        is MetadataValue.Int32 -> value
        is MetadataValue.Float32 -> value
        is MetadataValue.Bool -> value
        is MetadataValue.StringVal -> value
        is MetadataValue.UInt64 -> value
        is MetadataValue.Int64 -> value
        is MetadataValue.Float64 -> value
        is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() }
    }

    /**
     * Checks that [file] begins with the GGUF magic number.
     *
     * @param file candidate GGUF model file on local storage.
     * @return true if the file starts with the "GGUF" magic, otherwise false.
     * @throws IOException if fewer than four bytes can be read.
     */
    override suspend fun ensureSourceFileFormat(file: File): Boolean =
        file.inputStream().buffered().use { ensureMagic(it) }

    /**
     * Checks that the content behind [uri] begins with the GGUF magic number.
     *
     * @param context Context for obtaining ContentResolver.
     * @param uri Uri to the GGUF file provided by a ContentProvider.
     * @return true if the stream starts with the "GGUF" magic; false otherwise
     *         (including when the Uri cannot be opened at all).
     */
    override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean =
        context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true

    /** Reads the 4-byte magic; returns true iff it equals "GGUF". Throws on premature EOF. */
    private fun ensureMagic(input: InputStream): Boolean =
        input.readNBytesExact(4).contentEquals(GGUF_MAGIC)

    /**
     * High-level entry point: parses a `.gguf` stream and returns the fully
     * populated [GgufMetadata] tree.
     *
     * Steps performed internally:
     * 1. Reads and validates the 8-byte header (`"GGUF"` magic + version).
     * 2. Streams through the key-value section, skipping large blobs if the key
     *    appears in [skipKeys] or an array exceeds [arraySummariseThreshold].
     * 3. Converts the resulting raw map into strongly-typed sub-structures
     *    (basic info, tokenizer, rope, etc.).
     *
     * The method is STREAMING-ONLY: tensor data is never read.
     *
     * @param input open stream positioned at the start of a GGUF file.
     * @return a [GgufMetadata] instance containing all recognised metadata.
     * @throws IOException if the stream is not GGUF, the version is unsupported,
     *         or the metadata block is truncated / corrupt.
     */
    override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata {
        // -- 1. header: magic + version (throws on mismatch), tensor/kv counts --
        val version = ensureMagicAndVersion(input)
        val tensorCount = readLittleLong(input)
        val kvCount = readLittleLong(input)

        // -- 2. metadata map: stream through every key-value pair --
        val meta = readMetaMap(input, kvCount)

        // -- 3. build the structured object from the raw map --
        return buildStructured(meta, version, tensorCount, kvCount)
    }

    /** Reads the 4-byte magic + 4-byte version; throws if magic is not "GGUF". */
    private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion {
        if (!ensureMagic(input)) throw InvalidFileFormatException()
        return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input))
    }

    /**
     * Read an unsigned 32-bit little-endian integer (returned as the raw Int bit pattern).
     *
     * @throws IOException if fewer than four bytes are available.
     */
    private fun readLEUInt32(input: InputStream): Int =
        littleEndianBytesToInt(input.readNBytesExact(4))

    /**
     * Low-level helper that reads the entire key-value section from the current
     * stream position.
     *
     * Honours [skipKeys] and [arraySummariseThreshold] by invoking [skipValue]
     * or [parseValue] accordingly.
     *
     * @param input open stream positioned JUST AFTER the header.
     * @param kvCnt number of key-value pairs (taken from the header).
     * @return map with one [MetadataValue] for every key that is NOT skipped.
     */
    private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> {
        // Guard before .toInt(): a corrupt header could otherwise truncate/negate the count.
        if (kvCnt < 0 || kvCnt > Int.MAX_VALUE) throw IOException("Invalid metadata pair count: $kvCnt")
        return mutableMapOf<String, MetadataValue>().apply {
            repeat(kvCnt.toInt()) {
                val key = readString(input)
                val valueT = MetadataType.fromCode(readLEUInt32(input))
                if (key in skipKeys) {
                    skipValue(input, valueT)
                } else {
                    this[key] = parseValue(input, valueT)
                }
            }
        }
    }

    /**
     * Converts a flat [Map]<[String], [MetadataValue]> into the strongly-typed
     * [GgufMetadata] tree used by the rest of the app.
     *
     * Only the keys listed in the GGUF spec are copied into dedicated data
     * classes. Each optional section collapses to `null` (via the local
     * nullable-receiver [takeUnless]) when none of its keys were present.
     *
     * @param m raw key/value map.
     * @param version GGUF file-format version (enum).
     * @param tensorCnt number of tensors (from the header).
     * @param kvCnt total metadata pair count (from the header).
     */
    private fun buildStructured(
        m: Map<String, MetadataValue>,
        version: GgufMetadata.GgufVersion,
        tensorCnt: Long,
        kvCnt: Long
    ): GgufMetadata {
        // ---------- typed accessors over the raw map ----------
        fun String.str() = (m[this] as? MetadataValue.StringVal)?.value
        fun String.bool() = (m[this] as? MetadataValue.Bool)?.value
        fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt()
        fun String.f32() = (m[this] as? MetadataValue.Float32)?.value
        fun String.strList(): List<String>? =
            (m[this] as? MetadataValue.ArrayVal)
                ?.elements
                ?.mapNotNull { (it as? MetadataValue.StringVal)?.value }

        // Architecture prefixes many keys below; default per GGUF convention.
        val arch = "general.architecture".str() ?: ARCH_LLAMA

        // ---------- populate sections ----------
        val basic = GgufMetadata.BasicInfo(
            uuid = "general.uuid".str(),
            name = "general.basename".str(),
            nameLabel = "general.name".str(),
            sizeLabel = "general.size_label".str()
        )

        val author = GgufMetadata.AuthorInfo(
            organization = "general.organization".str(),
            author = "general.author".str(),
            doi = "general.doi".str(),
            url = "general.url".str(),
            repoUrl = "general.repo_url".str(),
            license = "general.license".str(),
            licenseLink = "general.license.link".str()
        ).takeUnless {
            organization == null && author == null && doi == null &&
                url == null && repoUrl == null && license == null && licenseLink == null
        }

        val additional = GgufMetadata.AdditionalInfo(
            type = "general.type".str(),
            description = "general.description".str(),
            tags = "general.tags".strList(),
            languages = "general.languages".strList()
        ).takeUnless {
            type == null && description == null && tags == null && languages == null
        }

        val architectureInfo = GgufMetadata.ArchitectureInfo(
            architecture = arch,
            fileType = "general.file_type".u32(),
            vocabSize = "$arch.vocab_size".u32(),
            finetune = "general.finetune".str(),
            quantizationVersion = "general.quantization_version".u32()
        ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null }

        val baseModels = buildList {
            val n = "general.base_model.count".u32() ?: 0
            for (i in 0 until n) {
                fun k(s: String) = "general.base_model.$i.$s"
                add(
                    GgufMetadata.BaseModelInfo(
                        name = k("name").str(),
                        author = k("author").str(),
                        version = k("version").str(),
                        organization = k("organization").str(),
                        url = k("url").str(),
                        doi = k("doi").str(),
                        uuid = k("uuid").str(),
                        repoUrl = k("repo_url").str(),
                    )
                )
            }
        }.takeIf { it.isNotEmpty() }

        val tokenizer = GgufMetadata.TokenizerInfo(
            model = "tokenizer.ggml.model".str(),
            bosTokenId = "tokenizer.ggml.bos_token_id".u32(),
            eosTokenId = "tokenizer.ggml.eos_token_id".u32(),
            unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(),
            paddingTokenId = "tokenizer.ggml.padding_token_id".u32(),
            addBosToken = "tokenizer.ggml.add_bos_token".bool(),
            addEosToken = "tokenizer.ggml.add_eos_token".bool(),
            chatTemplate = "tokenizer.chat_template".str()
        ).takeUnless {
            model == null && bosTokenId == null && eosTokenId == null &&
                unknownTokenId == null && paddingTokenId == null &&
                addBosToken == null && addEosToken == null && chatTemplate == null
        }

        val dimensions = GgufMetadata.DimensionsInfo(
            contextLength = "$arch.context_length".u32(),
            embeddingSize = "$arch.embedding_length".u32(),
            blockCount = "$arch.block_count".u32(),
            feedForwardSize = "$arch.feed_forward_length".u32()
        ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null }

        val attention = GgufMetadata.AttentionInfo(
            headCount = "$arch.attention.head_count".u32(),
            headCountKv = "$arch.attention.head_count_kv".u32(),
            keyLength = "$arch.attention.key_length".u32(),
            valueLength = "$arch.attention.value_length".u32(),
            layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(),
            layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(),
        ).takeUnless {
            headCount == null && headCountKv == null && keyLength == null && valueLength == null &&
                layerNormEpsilon == null && layerNormRmsEpsilon == null
        }

        val rope = GgufMetadata.RopeInfo(
            frequencyBase = "$arch.rope.freq_base".f32(),
            dimensionCount = "$arch.rope.dimension_count".u32(),
            scalingType = "$arch.rope.scaling.type".str(),
            scalingFactor = "$arch.rope.scaling.factor".f32(),
            attnFactor = "$arch.rope.scaling.attn_factor".f32(),
            originalContextLength = "$arch.rope.scaling.original_context_length".u32(),
            finetuned = "$arch.rope.scaling.finetuned".bool()
        ).takeUnless {
            frequencyBase == null && dimensionCount == null &&
                scalingType == null && scalingFactor == null && attnFactor == null &&
                originalContextLength == null && finetuned == null
        }

        val experts = GgufMetadata.ExpertsInfo(
            count = "$arch.expert_count".u32(),
            usedCount = "$arch.expert_used_count".u32()
        ).takeUnless { count == null && usedCount == null }

        return GgufMetadata(
            version = version,
            tensorCount = tensorCnt,
            kvCount = kvCnt,
            basic = basic,
            author = author,
            additional = additional,
            architecture = architectureInfo,
            baseModels = baseModels,
            tokenizer = tokenizer,
            dimensions = dimensions,
            attention = attention,
            rope = rope,
            experts = experts
        )
    }

    /**
     * Recursively parses a metadata value of the given type from the input stream.
     *
     * @param input the input stream positioned at the start of the value.
     * @param type the metadata value type to parse.
     * @throws IOException on premature EOF or malformed data.
     */
    private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) {
        MetadataType.UINT8 -> MetadataValue.UInt8(readByte(input, "uint8").toUByte())
        MetadataType.INT8 -> MetadataValue.Int8(readByte(input, "int8").toByte())
        MetadataType.UINT16 -> MetadataValue.UInt16(littleEndianBytesToInt(input.readNBytesExact(2)).toUShort())
        MetadataType.INT16 -> MetadataValue.Int16(littleEndianBytesToInt(input.readNBytesExact(2)).toShort())
        // .toUInt() reinterprets the 32-bit pattern, matching the on-disk unsigned value.
        MetadataType.UINT32 -> MetadataValue.UInt32(littleEndianBytesToInt(input.readNBytesExact(4)).toUInt())
        MetadataType.INT32 -> MetadataValue.Int32(littleEndianBytesToInt(input.readNBytesExact(4)))
        MetadataType.FLOAT32 -> MetadataValue.Float32(Float.fromBits(littleEndianBytesToInt(input.readNBytesExact(4))))
        MetadataType.BOOL -> {
            val byteVal = readByte(input, "boolean")
            if (byteVal != 0 && byteVal != 1) {
                throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).")
            }
            MetadataValue.Bool(byteVal != 0)
        }
        MetadataType.STRING -> MetadataValue.StringVal(readString(input))
        MetadataType.ARRAY -> {
            val elemType = MetadataType.fromCode(readLEUInt32(input))
            val len = readLittleLong(input)
            if (len < 0 || len > Int.MAX_VALUE) throw IOException("Invalid array length: $len")
            val count = len.toInt()

            if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) {
                // Fast-forward without allocating element objects.
                repeat(count) { skipValue(input, elemType) }
                MetadataValue.StringVal("Array($elemType, $count items) /* summarised */")
            } else {
                val list = ArrayList<MetadataValue>(count)
                repeat(count) { list += parseValue(input, elemType) }
                MetadataValue.ArrayVal(elemType, list)
            }
        }
        // .toULong() reinterprets the 64-bit pattern for the full 0..2^64-1 range.
        MetadataType.UINT64 -> MetadataValue.UInt64(littleEndianBytesToLong(input.readNBytesExact(8)).toULong())
        MetadataType.INT64 -> MetadataValue.Int64(littleEndianBytesToLong(input.readNBytesExact(8)))
        MetadataType.FLOAT64 -> MetadataValue.Float64(Double.fromBits(littleEndianBytesToLong(input.readNBytesExact(8))))
    }

    /**
     * Intentionally shadows stdlib `takeUnless`: the receiver is nullable and the
     * predicate is a `this`-receiver lambda, so section builders above can test
     * their own properties by bare name (e.g. `count == null`).
     */
    private fun <T> T?.takeUnless(check: T.() -> Boolean): T? =
        this?.takeIf { !it.check() }

    /** Skips a value in the stream without storing it (still advances the pointer). */
    private fun skipValue(input: InputStream, type: MetadataType) {
        when (type) {
            MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1)
            MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2)
            MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4)
            MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8)
            MetadataType.STRING -> {
                val len = readLittleLong(input); input.skipFully(len)
            }
            MetadataType.ARRAY -> {
                val elemType = MetadataType.fromCode(readLEUInt32(input))
                val len = readLittleLong(input)
                if (len < 0 || len > Int.MAX_VALUE) throw IOException("Invalid array length: $len")
                repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip
            }
        }
    }

    /** Reads one byte; throws with a typed message on EOF. */
    private fun readByte(input: InputStream, what: String): Int {
        val b = input.read()
        if (b == -1) throw IOException("Unexpected EOF while reading $what value.")
        return b
    }

    /**
     * Reads an 8-byte little-endian value as a signed Long.
     *
     * Note: if bit 63 is set the result is negative (two's complement); in our
     * context (lengths/counts) such extreme values are treated as corrupt by the
     * callers' range checks.
     */
    private fun readLittleLong(input: InputStream): Long =
        littleEndianBytesToLong(input.readNBytesExact(8))

    /** Reads a GGUF string (8-byte little-endian length followed by UTF-8 bytes). */
    private fun readString(input: InputStream): String {
        val len = readLittleLong(input)
        if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len")
        return String(input.readNBytesExact(len.toInt()), Charsets.UTF_8)
    }

    /** Assembles up to 4 little-endian bytes into a 32-bit integer. */
    private fun littleEndianBytesToInt(bytes: ByteArray): Int {
        var result = 0
        for (i in bytes.indices.reversed()) {
            result = (result shl 8) or (bytes[i].toInt() and 0xFF)
        }
        return result
    }

    /** Assembles up to 8 little-endian bytes into a 64-bit integer. */
    private fun littleEndianBytesToLong(bytes: ByteArray): Long {
        var result = 0L
        for (i in bytes.indices.reversed()) {
            result = (result shl 8) or (bytes[i].toLong() and 0xFFL)
        }
        return result
    }

    /**
     * Robust skip that works the same on JDK 11 and Android's desugared runtime.
     * Falls back to read-and-discard when `skip()` returns 0, which happens on
     * some Android streams.
     *
     * @param n number of bytes to advance in the stream.
     * @throws IOException on premature EOF.
     */
    private fun InputStream.skipFully(n: Long) {
        var remaining = n
        val scratch = ByteArray(8192) // read-and-toss buffer
        while (remaining > 0) {
            val skipped = skip(remaining)
            when {
                skipped > 0 -> remaining -= skipped // normal fast path
                skipped == 0L -> {
                    val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt())
                    if (read == -1) throw IOException("EOF while skipping $n bytes")
                    remaining -= read
                }
                else -> throw IOException("Skip returned negative value")
            }
        }
    }

    /**
     * Keeps reading until the requested number of bytes are filled; a single
     * `read()` may legitimately return fewer bytes without being at EOF.
     *
     * @param buf destination buffer.
     * @param len number of bytes to fill (defaults to `buf.size`).
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) {
        var off = 0
        while (off < len) {
            val n = read(buf, off, len - off)
            if (n == -1) throw IOException("EOF after $off of $len bytes")
            off += n
        }
    }

    /**
     * Reads EXACTLY `n` bytes or throws - never returns a partially-filled array.
     * Loops internally (via [readFully]) so short reads are retried rather than
     * misreported as EOF.
     *
     * @throws IOException on premature EOF.
     */
    private fun InputStream.readNBytesExact(n: Int): ByteArray =
        ByteArray(n).also { readFully(it) }

    companion object {
        private const val ARCH_LLAMA = "llama"

        /** ASCII bytes of the "GGUF" magic number. */
        private val GGUF_MAGIC = byteArrayOf(0x47, 0x47, 0x55, 0x46)
    }
}
