diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/examples/llama.android/lib/src/main/java | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/examples/llama.android/lib/src/main/java')
7 files changed, 1287 insertions, 0 deletions
// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/AiChat.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat

import android.content.Context
import com.arm.aichat.internal.InferenceEngineImpl

/**
 * Main entry point for Arm's AI Chat library.
 */
object AiChat {
    /**
     * Get the inference engine single instance.
     *
     * @param context Context used to resolve the app's native library directory.
     */
    fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context)
}

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/InferenceEngine.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat

import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow

/**
 * Interface defining the core LLM inference operations.
 */
interface InferenceEngine {
    /**
     * Current state of the inference engine.
     */
    val state: StateFlow<State>

    /**
     * Load a model from the given path.
     *
     * @param pathToModel absolute path to a readable model file on disk
     * @throws UnsupportedArchitectureException if model architecture not supported
     */
    suspend fun loadModel(pathToModel: String)

    /**
     * Sends a system prompt to the loaded model.
     *
     * @param systemPrompt non-blank plain-text system prompt
     */
    suspend fun setSystemPrompt(systemPrompt: String)

    /**
     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
     *
     * @param message non-empty plain-text user prompt
     * @param predictLength maximum number of tokens to generate
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>

    /**
     * Runs a benchmark with the specified parameters.
     *
     * NOTE(review): parameter names follow llama.cpp's bench convention —
     * presumably pp = prompt-processing tokens, tg = generated tokens,
     * pl = parallel sequences, nr = repetitions. Confirm against ai_chat.cpp.
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String

    /**
     * Unloads the currently loaded model.
     */
    fun cleanUp()

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy()

    /**
     * States of the inference engine
     */
    sealed class State {
        object Uninitialized : State()
        object Initializing : State()
        object Initialized : State()

        object LoadingModel : State()
        object UnloadingModel : State()
        object ModelReady : State()

        object Benchmarking : State()
        object ProcessingSystemPrompt : State()
        object ProcessingUserPrompt : State()

        object Generating : State()

        data class Error(val exception: Exception) : State()
    }

    companion object {
        const val DEFAULT_PREDICT_LENGTH = 1024
    }
}

/** True while an operation is in flight that must not be interrupted. */
val State.isUninterruptible: Boolean // explicit type added for consistency with isModelLoaded
    get() = this is State.Initializing ||
            this is State.LoadingModel ||
            this is State.UnloadingModel ||
            this is State.Benchmarking ||
            this is State.ProcessingSystemPrompt ||
            this is State.ProcessingUserPrompt

/** True whenever a model is resident in memory (ready, busy, or generating). */
val State.isModelLoaded: Boolean
    get() = this is State.ModelReady ||
            this is State.Benchmarking ||
            this is State.ProcessingSystemPrompt ||
            this is State.ProcessingUserPrompt ||
            this is State.Generating

/** Thrown by [InferenceEngine.loadModel] when the model architecture is unsupported. */
class UnsupportedArchitectureException : Exception()

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/gguf/FileType.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat.gguf

// NOTE(review): removed spurious, unused `import kotlin.collections.get`
// (an IDE auto-import artifact).

/**
 * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`).
 * The `label` matches what llama-cli prints.
 */
enum class FileType(val code: Int, val label: String) {
    ALL_F32(0, "all F32"),
    MOSTLY_F16(1, "F16"),
    MOSTLY_Q4_0(2, "Q4_0"),
    MOSTLY_Q4_1(3, "Q4_1"),
    // 4 removed
    MOSTLY_Q8_0(7, "Q8_0"),
    MOSTLY_Q5_0(8, "Q5_0"),
    MOSTLY_Q5_1(9, "Q5_1"),

    /* K-quants ------------------------------------------------------------ */
    MOSTLY_Q2_K (10, "Q2_K - Medium"),
    MOSTLY_Q3_K_S (11, "Q3_K - Small"),
    MOSTLY_Q3_K_M (12, "Q3_K - Medium"),
    MOSTLY_Q3_K_L (13, "Q3_K - Large"),
    MOSTLY_Q4_K_S (14, "Q4_K - Small"),
    MOSTLY_Q4_K_M (15, "Q4_K - Medium"),
    MOSTLY_Q5_K_S (16, "Q5_K - Small"),
    MOSTLY_Q5_K_M (17, "Q5_K - Medium"),
    MOSTLY_Q6_K (18, "Q6_K"),

    /* IQ quants ----------------------------------------------------------- */
    MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"),
    MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"),
    MOSTLY_Q2_K_S (21, "Q2_K - Small"),
    MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"),
    MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"),
    MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"),
    MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"),
    MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"),
    MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"),
    MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"),
    MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"),
    MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"),
    MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"),

    /* BF16 & Ternary ------------------------------------------------------ */
    MOSTLY_BF16 (32, "BF16"),
    MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"),
    MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"),

    /* Special flag -------------------------------------------------------- */
    GUESSED(1024, "(guessed)"),

    UNKNOWN(-1, "unknown");

    companion object {
        // One-time reverse lookup table: code -> enum entry.
        private val map = entries.associateBy(FileType::code)

        /** Maps a raw `general.file_type` code to its enum entry; null/unrecognised codes map to [UNKNOWN]. */
        fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN
    }
}

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/gguf/GgufMetadata.kt (new file) — content follows
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat.gguf

import java.io.IOException


/**
 * Structured metadata of GGUF
 */
data class GgufMetadata(
    // Basic file info
    val version: GgufVersion,
    val tensorCount: Long,
    val kvCount: Long,

    // General info
    val basic: BasicInfo,
    val author: AuthorInfo? = null,
    val additional: AdditionalInfo? = null,
    val architecture: ArchitectureInfo? = null,
    val baseModels: List<BaseModelInfo>? = null,
    val tokenizer: TokenizerInfo? = null,

    // Derivative info
    val dimensions: DimensionsInfo? = null,
    val attention: AttentionInfo? = null,
    val rope: RopeInfo? = null,
    val experts: ExpertsInfo? = null
) {
    enum class GgufVersion(val code: Int, val label: String) {
        /** First public draft; little-endian only, no alignment key. */
        LEGACY_V1(1, "Legacy v1"),

        /** Added split-file support and some extra metadata keys. */
        EXTENDED_V2(2, "Extended v2"),

        /** Current spec: endian-aware, mandatory alignment, fully validated. */
        VALIDATED_V3(3, "Validated v3");

        companion object {
            /** @throws IOException if [code] does not match a known GGUF version. */
            fun fromCode(code: Int): GgufVersion =
                entries.firstOrNull { it.code == code }
                    ?: throw IOException("Unknown GGUF version code $code")
        }

        override fun toString(): String = "$label (code=$code)"
    }

    data class BasicInfo(
        val uuid: String? = null,
        val name: String? = null,
        val nameLabel: String? = null,
        val sizeLabel: String? = null, // Size label like "7B"
    )

    data class AuthorInfo(
        val organization: String? = null,
        val author: String? = null,
        val doi: String? = null,
        val url: String? = null,
        val repoUrl: String? = null,
        val license: String? = null,
        val licenseLink: String? = null,
    )

    data class AdditionalInfo(
        val type: String? = null,
        val description: String? = null,
        val tags: List<String>? = null,
        val languages: List<String>? = null,
    )

    data class ArchitectureInfo(
        val architecture: String? = null,
        val fileType: Int? = null,
        val vocabSize: Int? = null,
        val finetune: String? = null,
        val quantizationVersion: Int? = null,
    )

    data class BaseModelInfo(
        val name: String? = null,
        val author: String? = null,
        val version: String? = null,
        val organization: String? = null,
        val url: String? = null,
        val doi: String? = null,
        val uuid: String? = null,
        val repoUrl: String? = null,
    )

    data class TokenizerInfo(
        val model: String? = null,
        val bosTokenId: Int? = null,
        val eosTokenId: Int? = null,
        val unknownTokenId: Int? = null,
        val paddingTokenId: Int? = null,
        val addBosToken: Boolean? = null,
        val addEosToken: Boolean? = null,
        val chatTemplate: String? = null,
    )

    data class DimensionsInfo(
        val contextLength: Int? = null,
        val embeddingSize: Int? = null,
        val blockCount: Int? = null,
        val feedForwardSize: Int? = null,
    )

    data class AttentionInfo(
        val headCount: Int? = null,
        val headCountKv: Int? = null,
        val keyLength: Int? = null,
        val valueLength: Int? = null,
        val layerNormEpsilon: Float? = null,
        val layerNormRmsEpsilon: Float? = null,
    )

    data class RopeInfo(
        val frequencyBase: Float? = null,
        val dimensionCount: Int? = null,
        val scalingType: String? = null,
        val scalingFactor: Float? = null,
        val attnFactor: Float? = null,
        val originalContextLength: Int? = null,
        val finetuned: Boolean? = null,
    )

    data class ExpertsInfo(
        val count: Int? = null,
        val usedCount: Int? = null,
    )
}

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/gguf/GgufMetadataReader.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat.gguf

import android.content.Context
import android.net.Uri
import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl
import java.io.File
import java.io.IOException
import java.io.InputStream

/**
 * Interface for reading GGUF metadata from model files.
 * Use `GgufMetadataReader.create()` to get an instance.
 */
interface GgufMetadataReader {
    /**
     * Reads the magic number from the specified file path.
     *
     * @param file Java File to the GGUF file with absolute path
     * @return true if file is valid GGUF, otherwise false
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun ensureSourceFileFormat(file: File): Boolean

    /**
     * Reads the magic number from the given content Uri.
     * (Fixed copy-paste doc: this overload reads via a ContentResolver, not a file path.)
     *
     * @param context Context for obtaining [android.content.ContentProvider]
     * @param uri Uri to the GGUF file provided by [android.content.ContentProvider]
     * @return true if file is valid GGUF, otherwise false
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean

    /**
     * Reads and parses GGUF metadata from the given input stream.
     * (Fixed copy-paste doc: this method consumes an [InputStream], not a file path.)
     *
     * @param input the [InputStream] obtained from a readable file or content
     * @return Structured metadata extracted from the file
     * @throws IOException if file is damaged or cannot be read
     * @throws InvalidFileFormatException if file format is invalid
     */
    suspend fun readStructuredMetadata(input: InputStream): GgufMetadata

    companion object {
        // Keys whose (potentially huge) values are skipped by default:
        // tokenizer vocabularies and chat templates can be megabytes of data.
        private val DEFAULT_SKIP_KEYS = setOf(
            "tokenizer.chat_template",
            "tokenizer.ggml.scores",
            "tokenizer.ggml.tokens",
            "tokenizer.ggml.token_type"
        )

        /**
         * Creates a default GgufMetadataReader instance
         */
        fun create(): GgufMetadataReader = GgufMetadataReaderImpl(
            skipKeys = DEFAULT_SKIP_KEYS,
            arraySummariseThreshold = 1_000
        )

        /**
         * Creates a GgufMetadataReader with custom configuration
         *
         * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map)
         * @param arraySummariseThreshold If >=0, arrays longer get summarised, not materialised;
         *                                If -1, never summarise.
         */
        fun create(
            skipKeys: Set<String> = DEFAULT_SKIP_KEYS,
            arraySummariseThreshold: Int = 1_000
        ): GgufMetadataReader = GgufMetadataReaderImpl(
            skipKeys = skipKeys,
            arraySummariseThreshold = arraySummariseThreshold
        )
    }
}

/** Thrown when a stream does not begin with the 4-byte "GGUF" magic. */
class InvalidFileFormatException : IOException()

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/internal/InferenceEngineImpl.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat.internal

import android.content.Context
import android.util.Log
import com.arm.aichat.InferenceEngine
import com.arm.aichat.UnsupportedArchitectureException
import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance
import dalvik.annotation.optimization.FastNative
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import java.io.File
import java.io.IOException

/**
 * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models.
 *
 * This class implements a singleton pattern for managing the lifecycle of a single LLM instance.
 * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety
 * with the underlying C++ native code.
 *
 * The typical usage flow is:
 * 1. Get instance via [getInstance]
 * 2. Load a model with [loadModel]
 * 3. Send prompts with [sendUserPrompt]
 * 4. Generate responses as token streams
 * 5. Perform [cleanUp] when done with a model
 * 6. Properly [destroy] when completely done
 *
 * State transitions are managed automatically and validated at each operation.
 *
 * @see ai_chat.cpp for the native implementation details
 */
internal class InferenceEngineImpl private constructor(
    private val nativeLibDir: String
) : InferenceEngine {

    companion object {
        private val TAG = InferenceEngineImpl::class.java.simpleName

        @Volatile
        private var instance: InferenceEngine? = null

        /**
         * Create or obtain [InferenceEngineImpl]'s single instance.
         *
         * @param context Context for obtaining native library directory
         * @throws IllegalArgumentException if native library path is invalid
         * @throws UnsatisfiedLinkError if library failed to load
         */
        internal fun getInstance(context: Context) =
            instance ?: synchronized(this) {
                // FIX(review): re-check inside the lock. Without this second check,
                // two threads racing past the outer null check would each construct
                // an engine, and the second would silently overwrite the first.
                instance ?: run {
                    val nativeLibDir = context.applicationInfo.nativeLibraryDir
                    require(nativeLibDir.isNotBlank()) { "Expected a valid native library path!" }

                    try {
                        Log.i(TAG, "Instantiating InferenceEngineImpl...")
                        InferenceEngineImpl(nativeLibDir).also { instance = it }
                    } catch (e: UnsatisfiedLinkError) {
                        Log.e(TAG, "Failed to load native library from $nativeLibDir", e)
                        throw e
                    }
                }
            }
    }

    /**
     * JNI methods
     * @see ai_chat.cpp
     */
    @FastNative
    private external fun init(nativeLibDir: String)

    @FastNative
    private external fun load(modelPath: String): Int

    @FastNative
    private external fun prepare(): Int

    @FastNative
    private external fun systemInfo(): String

    @FastNative
    private external fun benchModel(pp: Int, tg: Int, pl: Int, nr: Int): String

    @FastNative
    private external fun processSystemPrompt(systemPrompt: String): Int

    @FastNative
    private external fun processUserPrompt(userPrompt: String, predictLength: Int): Int

    @FastNative
    private external fun generateNextToken(): String?

    @FastNative
    private external fun unload()

    @FastNative
    private external fun shutdown()

    private val _state =
        MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized)
    override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow()

    // System prompt must be the first thing processed after a model load;
    // this latch enforces that ordering.
    private var _readyForSystemPrompt = false

    // Set from any thread to cooperatively stop the generation loop.
    @Volatile
    private var _cancelGeneration = false

    /**
     * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations
     */
    @OptIn(ExperimentalCoroutinesApi::class)
    private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1)
    private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob())

    init {
        // Load and initialise the native library asynchronously on the llama thread.
        // NOTE(review): a failure here rethrows inside the scope and leaves the
        // engine in Initializing; callers observe progress via [state].
        llamaScope.launch {
            try {
                check(_state.value is InferenceEngine.State.Uninitialized) {
                    "Cannot load native library in ${_state.value.javaClass.simpleName}!"
                }
                _state.value = InferenceEngine.State.Initializing
                Log.i(TAG, "Loading native library...")
                System.loadLibrary("ai-chat")
                init(nativeLibDir)
                _state.value = InferenceEngine.State.Initialized
                Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}")

            } catch (e: Exception) {
                Log.e(TAG, "Failed to load native library", e)
                throw e
            }
        }
    }

    /**
     * Load the LLM
     */
    override suspend fun loadModel(pathToModel: String) =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.Initialized) {
                "Cannot load model in ${_state.value.javaClass.simpleName}!"
            }

            try {
                Log.i(TAG, "Checking access to model file... \n$pathToModel")
                File(pathToModel).let {
                    require(it.exists()) { "File not found" }
                    require(it.isFile) { "Not a valid file" }
                    require(it.canRead()) { "Cannot read file" }
                }

                Log.i(TAG, "Loading model... \n$pathToModel")
                _readyForSystemPrompt = false
                _state.value = InferenceEngine.State.LoadingModel
                load(pathToModel).let {
                    // TODO-han.yin: find a better way to pass other error codes
                    if (it != 0) throw UnsupportedArchitectureException()
                }
                prepare().let {
                    if (it != 0) throw IOException("Failed to prepare resources")
                }
                Log.i(TAG, "Model loaded!")
                _readyForSystemPrompt = true

                _cancelGeneration = false
                _state.value = InferenceEngine.State.ModelReady
            } catch (e: Exception) {
                Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e)
                _state.value = InferenceEngine.State.Error(e)
                throw e
            }
        }

    /**
     * Process the plain text system prompt
     *
     * TODO-han.yin: return error code if system prompt not correct processed?
     */
    override suspend fun setSystemPrompt(systemPrompt: String) =
        withContext(llamaDispatcher) {
            require(systemPrompt.isNotBlank()) { "Cannot process empty system prompt!" }
            check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" }
            check(_state.value is InferenceEngine.State.ModelReady) {
                "Cannot process system prompt in ${_state.value.javaClass.simpleName}!"
            }

            Log.i(TAG, "Sending system prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingSystemPrompt
            processSystemPrompt(systemPrompt).let { result ->
                if (result != 0) {
                    RuntimeException("Failed to process system prompt: $result").also {
                        _state.value = InferenceEngine.State.Error(it)
                        throw it
                    }
                }
            }
            Log.i(TAG, "System prompt processed! Awaiting user prompt...")
            _state.value = InferenceEngine.State.ModelReady
        }

    /**
     * Send plain text user prompt to LLM, which starts generating tokens in a [Flow]
     */
    override fun sendUserPrompt(
        message: String,
        predictLength: Int,
    ): Flow<String> = flow {
        require(message.isNotEmpty()) { "User prompt discarded due to being empty!" }
        check(_state.value is InferenceEngine.State.ModelReady) {
            "User prompt discarded due to: ${_state.value.javaClass.simpleName}"
        }

        try {
            Log.i(TAG, "Sending user prompt...")
            _readyForSystemPrompt = false
            _state.value = InferenceEngine.State.ProcessingUserPrompt

            processUserPrompt(message, predictLength).let { result ->
                if (result != 0) {
                    // FIX(review): previously this path returned while the state was
                    // still ProcessingUserPrompt, permanently wedging the engine.
                    // Surface the failure through the Error state instead.
                    val e = RuntimeException("Failed to process user prompt: $result")
                    Log.e(TAG, "Failed to process user prompt: $result")
                    _state.value = InferenceEngine.State.Error(e)
                    return@flow
                }
            }

            Log.i(TAG, "User prompt processed. Generating assistant prompt...")
            _state.value = InferenceEngine.State.Generating
            while (!_cancelGeneration) {
                generateNextToken()?.let { utf8token ->
                    if (utf8token.isNotEmpty()) emit(utf8token)
                } ?: break // null token signals end of generation
            }
            if (_cancelGeneration) {
                Log.i(TAG, "Assistant generation aborted per requested.")
            } else {
                Log.i(TAG, "Assistant generation complete. Awaiting user prompt...")
            }
            _state.value = InferenceEngine.State.ModelReady
        } catch (e: CancellationException) {
            Log.i(TAG, "Assistant generation's flow collection cancelled.")
            _state.value = InferenceEngine.State.ModelReady
            throw e // cooperative cancellation must propagate
        } catch (e: Exception) {
            Log.e(TAG, "Error during generation!", e)
            _state.value = InferenceEngine.State.Error(e)
            throw e
        }
    }.flowOn(llamaDispatcher)

    /**
     * Benchmark the model
     */
    override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String =
        withContext(llamaDispatcher) {
            check(_state.value is InferenceEngine.State.ModelReady) {
                // FIX(review): was `$state`, which stringified the StateFlow wrapper
                // instead of the actual state, producing a useless message.
                "Benchmark request discarded due to: ${_state.value.javaClass.simpleName}"
            }
            Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)")
            _readyForSystemPrompt = false // Just to be safe
            _state.value = InferenceEngine.State.Benchmarking
            benchModel(pp, tg, pl, nr).also {
                _state.value = InferenceEngine.State.ModelReady
            }
        }

    /**
     * Unloads the model and frees resources, or reset error states
     */
    override fun cleanUp() {
        _cancelGeneration = true
        runBlocking(llamaDispatcher) {
            when (val state = _state.value) {
                is InferenceEngine.State.ModelReady -> {
                    Log.i(TAG, "Unloading model and free resources...")
                    _readyForSystemPrompt = false
                    _state.value = InferenceEngine.State.UnloadingModel

                    unload()

                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "Model unloaded!")
                    Unit
                }

                is InferenceEngine.State.Error -> {
                    Log.i(TAG, "Resetting error states...")
                    _state.value = InferenceEngine.State.Initialized
                    Log.i(TAG, "States reset!")
                    Unit
                }

                else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}")
            }
        }
    }

    /**
     * Cancel all ongoing coroutines and free GGML backends
     */
    override fun destroy() {
        _cancelGeneration = true
        runBlocking(llamaDispatcher) {
            _readyForSystemPrompt = false
            when (_state.value) {
                is InferenceEngine.State.Uninitialized -> {}
                is InferenceEngine.State.Initialized -> shutdown()
                else -> { unload(); shutdown() }
            }
        }
        llamaScope.cancel()
    }
}

// ════════════════════════════════════════════════════════════════════════════
// file: com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt (new file)
// ════════════════════════════════════════════════════════════════════════════
package com.arm.aichat.internal.gguf

import android.content.Context
import android.net.Uri
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.arm.aichat.gguf.InvalidFileFormatException
import java.io.File
import java.io.IOException
import java.io.InputStream


/**
 * Utility class to read GGUF model files and extract metadata key-value pairs.
 * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data.
+ */ +internal class GgufMetadataReaderImpl( + private val skipKeys: Set<String>, + private val arraySummariseThreshold: Int, +) : GgufMetadataReader { + companion object { + private const val ARCH_LLAMA = "llama" + } + + /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */ + enum class MetadataType(val code: Int) { + UINT8(0), INT8(1), UINT16(2), INT16(3), + UINT32(4), INT32(5), FLOAT32(6), BOOL(7), + STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12); + companion object { + private val codeMap = entries.associateBy(MetadataType::code) + fun fromCode(code: Int): MetadataType = codeMap[code] + ?: throw IOException("Unknown metadata value type code: $code") + } + } + + /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */ + sealed class MetadataValue { + data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int + data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int + data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian) + data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian) + data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian) + data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian) + data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float + data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true) + data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed) + data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue() + data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian) + data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian) + data 
class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double + } + + /* Convert MetadataValue to plain Kotlin primitives for allMetadata map */ + private fun MetadataValue.toPrimitive(): Any = when (this) { + is MetadataValue.UInt8 -> value + is MetadataValue.Int8 -> value + is MetadataValue.UInt16 -> value + is MetadataValue.Int16 -> value + is MetadataValue.UInt32 -> value + is MetadataValue.Int32 -> value + is MetadataValue.Float32 -> value + is MetadataValue.Bool -> value + is MetadataValue.StringVal -> value + is MetadataValue.UInt64 -> value + is MetadataValue.Int64 -> value + is MetadataValue.Float64 -> value + is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() } + } + + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + */ + override suspend fun ensureSourceFileFormat(file: File): Boolean = + file.inputStream().buffered().use { ensureMagic(it) } + + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + */ + override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean = + context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true + + /** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */ + private fun ensureMagic(input: InputStream): Boolean = + ByteArray(4).let { + if (input.read(it) != 4) throw IOException("Not a valid file!") + it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF" + } + + /** + * High‑level entry point: parses a `.gguf` file on disk and returns the fully + * populated [GgufMetadata] tree. + * + * Steps performed internally: + * 1. 
Reads and validates the 8‑byte header (`"GGUF"` magic + version). + * 2. Streams through the key‑value section, skipping large blobs if the key + * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold]. + * 3. Converts the resulting raw map into strongly‑typed sub‑structures + * (basic info, tokenizer, rope, etc.). + * + * The method is STREAMING‑ONLY: tensors are never mapped or loaded into + * memory, so even multi‑GB model files can be processed in < 50 ms. + * + * @param path Absolute or relative filesystem path to a `.gguf` file. + * @return A [GgufMetadata] instance containing all recognised metadata plus + * an `allMetadata` map with any keys that were not given a dedicated + * field. + * @throws IOException if the file is not GGUF, the version is unsupported, + * or the metadata block is truncated / corrupt. + */ + override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata { + // ── 1. header ────────────────────────────────────────────────────────── + // throws on mismatch + val version = ensureMagicAndVersion(input) + val tensorCount = readLittleLong(input) + val kvCount = readLittleLong(input) + + // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ── + val meta = readMetaMap(input, kvCount) // <String, MetadataValue> + + // ── 3. build structured object ──────────────────────────────────────── + return buildStructured(meta, version, tensorCount, kvCount) + } + + /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */ + private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion { + if (!ensureMagic(input)) throw InvalidFileFormatException() + return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input)) + } + + /** + * Read an unsigned 32‑bit little‑endian integer. + * + * @throws IOException if fewer than four bytes are available. 
+ */ + private fun readLEUInt32(input: InputStream): Int { + val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read() + if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32") + return (b3 and 0xFF shl 24) or + (b2 and 0xFF shl 16) or + (b1 and 0xFF shl 8) or + (b0 and 0xFF) + } + + /** + * Low‑level helper that reads the entire “key-value” section from the current + * stream position. + * + * @param input Open stream positioned JUST AFTER the header. + * @param kvCnt Number of key‑value pairs (taken from the header). + * @return Mutable map with one [MetadataValue] for every key that is NOT skipped. + * + * The function honours [skipKeys] and [arraySummariseThreshold] by invoking + * [skipValue] or [parseValue] accordingly. + */ + private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> = + mutableMapOf<String, MetadataValue>().apply { + repeat(kvCnt.toInt()) { + val key = readString(input) + val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + if (key in skipKeys) { + skipValue(input, valueT) + } else { + this[key] = parseValue(input, valueT) + } + } + } + + /** + * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed + * [GgufMetadata] tree used by the rest of the app. + * + * Only the keys listed in the spec are copied into dedicated data classes; + * everything else is preserved in `GgufMetadata.allMetadata`. + * + * @param m Raw key/value map. + * @param version GGUF file‑format version (enum). + * @param tensorCnt Number of tensors (from the header). + * @param kvCnt Total metadata pair count (from the header). + */ + private fun buildStructured( + m: Map<String, MetadataValue>, + version: GgufMetadata.GgufVersion, + tensorCnt: Long, + kvCnt: Long + ): GgufMetadata { + // ---------- helpers ---------- + fun String.str() = (m[this] as? MetadataValue.StringVal)?.value + fun String.bool() = (m[this] as? 
MetadataValue.Bool)?.value + fun String.i32() = (m[this] as? MetadataValue.Int32)?.value + fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt() + fun String.f32() = (m[this] as? MetadataValue.Float32)?.value + fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat() + fun String.strList(): List<String>? = + (m[this] as? MetadataValue.ArrayVal) + ?.elements + ?.mapNotNull { (it as? MetadataValue.StringVal)?.value } + + val arch = "general.architecture".str() ?: ARCH_LLAMA + + // -------------- populate sections ---------------- + val basic = GgufMetadata.BasicInfo( + uuid = "general.uuid".str(), + name = "general.basename".str(), + nameLabel = "general.name".str(), + sizeLabel = "general.size_label".str() + ) + + val author = GgufMetadata.AuthorInfo( + organization = "general.organization".str(), + author = "general.author".str(), + doi = "general.doi".str(), + url = "general.url".str(), + repoUrl = "general.repo_url".str(), + license = "general.license".str(), + licenseLink = "general.license.link".str() + ).takeUnless { + organization == null && author == null && doi == null && + url == null && repoUrl == null && license == null && licenseLink == null + } + + val additional = GgufMetadata.AdditionalInfo( + type = "general.type".str(), + description = "general.description".str(), + tags = "general.tags".strList(), + languages = "general.languages".strList() + ).takeUnless { + type == null && description == null && tags == null && languages == null + } + + val architectureInfo = GgufMetadata.ArchitectureInfo( + architecture = arch, + fileType = "general.file_type".u32(), + vocabSize = "$arch.vocab_size".u32(), + finetune = "general.finetune".str(), + quantizationVersion = "general.quantization_version".u32() + ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null } + + val baseModels = buildList { + val n = "general.base_model.count".u32() ?: 0 + for (i in 0 until n) { + fun k(s: 
String) = "general.base_model.$i.$s" + add( + GgufMetadata.BaseModelInfo( + name = k("name").str(), + author = k("author").str(), + version = k("version").str(), + organization = k("organization").str(), + url = k("url").str(), + doi = k("doi").str(), + uuid = k("uuid").str(), + repoUrl = k("repo_url").str(), + ) + ) + } + }.takeIf { it.isNotEmpty() } + + val tokenizer = GgufMetadata.TokenizerInfo( + model = "tokenizer.ggml.model".str(), + bosTokenId = "tokenizer.ggml.bos_token_id".u32(), + eosTokenId = "tokenizer.ggml.eos_token_id".u32(), + unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(), + paddingTokenId = "tokenizer.ggml.padding_token_id".u32(), + addBosToken = "tokenizer.ggml.add_bos_token".bool(), + addEosToken = "tokenizer.ggml.add_eos_token".bool(), + chatTemplate = "tokenizer.chat_template".str() + ).takeUnless { model == null && bosTokenId == null && eosTokenId == null && + unknownTokenId == null && paddingTokenId == null && + addBosToken == null && addEosToken == null && chatTemplate == null + } + + val dimensions = GgufMetadata.DimensionsInfo( + contextLength = "$arch.context_length".u32(), + embeddingSize = "$arch.embedding_length".u32(), + blockCount = "$arch.block_count".u32(), + feedForwardSize = "$arch.feed_forward_length".u32() + ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null } + + val attention = GgufMetadata.AttentionInfo( + headCount = "$arch.attention.head_count".u32(), + headCountKv = "$arch.attention.head_count_kv".u32(), + keyLength = "$arch.attention.key_length".u32(), + valueLength = "$arch.attention.value_length".u32(), + layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(), + layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(), + ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null && + layerNormEpsilon == null && layerNormRmsEpsilon == null + } + + val rope = GgufMetadata.RopeInfo( + 
frequencyBase = "$arch.rope.freq_base".f32(), + dimensionCount = "$arch.rope.dimension_count".u32(), + scalingType = "$arch.rope.scaling.type".str(), + scalingFactor = "$arch.rope.scaling.factor".f32(), + attnFactor = "$arch.rope.scaling.attn_factor".f32(), + originalContextLength = "$arch.rope.scaling.original_context_length".u32(), + finetuned = "$arch.rope.scaling.finetuned".bool() + ).takeUnless { frequencyBase == null && dimensionCount == null && + scalingType == null && scalingFactor == null && attnFactor == null && + originalContextLength == null && finetuned == null + } + + val experts = GgufMetadata.ExpertsInfo( + count = "$arch.expert_count".u32(), + usedCount = "$arch.expert_used_count".u32() + ).takeUnless { count == null && usedCount == null } + + return GgufMetadata( + version = version, + tensorCount = tensorCnt, + kvCount = kvCnt, + basic = basic, + author = author, + additional = additional, + architecture = architectureInfo, + baseModels = baseModels, + tokenizer = tokenizer, + dimensions = dimensions, + attention = attention, + rope = rope, + experts = experts + ) + } + + /** + * Recursively parses a metadata value of the given type from the input stream. + * @param input The input stream positioned at the start of the value. + * @param type The metadata value type to parse. 
+ */ + private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) { + MetadataType.UINT8 -> { + // 1-byte unsigned integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.") + MetadataValue.UInt8(byteVal.toUByte()) + } + MetadataType.INT8 -> { + // 1-byte signed integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.") + MetadataValue.Int8(byteVal.toByte()) + } + MetadataType.UINT16 -> { + // 2-byte unsigned integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.") + // Combine two bytes (little-endian) into an unsigned 16-bit value + val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.UInt16(u16.toUShort()) + } + MetadataType.INT16 -> { + // 2-byte signed integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.") + // Combine to 16-bit and interpret as signed + val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.Int16(i16.toShort()) + } + MetadataType.UINT32 -> { + // 4-byte unsigned integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.") + // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt + val u32 = (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.UInt32(u32.toUInt()) + } + MetadataType.INT32 -> { + // 4-byte signed integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.") + // Combine four bytes into a 32-bit signed int + val 
i32 = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + MetadataValue.Int32(i32) + } + MetadataType.FLOAT32 -> { + // 4-byte IEEE 754 float (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.") + // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float + val bits = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + val floatVal = Float.fromBits(bits) + MetadataValue.Float32(floatVal) + } + MetadataType.BOOL -> { + // 1-byte boolean (0 = false, 1 = true) + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.") + if (byteVal != 0 && byteVal != 1) { + throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).") + } + MetadataValue.Bool(byteVal != 0) + } + MetadataType.STRING -> { + // UTF-8 string (length-prefixed with 8-byte length) + val str = readString(input) + MetadataValue.StringVal(str) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + val count = len.toInt() + + if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) { + // fast‑forward without allocation + repeat(count) { skipValue(input, elemType) } + MetadataValue.StringVal("Array($elemType, $count items) /* summarised */") + } else { + val list = ArrayList<MetadataValue>(count) + repeat(count) { list += parseValue(input, elemType) } + MetadataValue.ArrayVal(elemType, list) + } + } + MetadataType.UINT64 -> { + // 8-byte unsigned integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.") + // Combine 8 bytes into an unsigned 64-bit (ULong). 
Use ULong for full 0 to 2^64-1 range. + val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or + (bytes[6].toULong() and 0xFFuL shl 48) or + (bytes[5].toULong() and 0xFFuL shl 40) or + (bytes[4].toULong() and 0xFFuL shl 32) or + (bytes[3].toULong() and 0xFFuL shl 24) or + (bytes[2].toULong() and 0xFFuL shl 16) or + (bytes[1].toULong() and 0xFFuL shl 8) or + (bytes[0].toULong() and 0xFFuL) + MetadataValue.UInt64(u64) + } + MetadataType.INT64 -> { + // 8-byte signed integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.") + // Combine 8 bytes into a signed 64-bit value (Long) + val i64 = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.Int64(i64) + } + MetadataType.FLOAT64 -> { + // 8-byte IEEE 754 double (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.") + // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double + val bits = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + val doubleVal = Double.fromBits(bits) + MetadataValue.Float64(doubleVal) + } + } + + + private fun <T> T?.takeUnless(check: T.() -> Boolean): T? = + this?.takeIf { !it.check() } + + /** Helper: Skip a value in the stream without storing it (still maintains pointer). 
*/ + private fun skipValue(input: InputStream, type: MetadataType) { + when (type) { + MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1) + MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2) + MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4) + MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8) + MetadataType.STRING -> { + val len = readLittleLong(input); input.skipFully(len) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip + } + } + } + + /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */ + private fun readLittleLong(input: InputStream): Long { + val bytes = ByteArray(8) + input.readFully(bytes) + + // Combine 8 bytes into a 64-bit value (Little Endian). + // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement). + // In our context (lengths/counts), such extremely large values are not expected. + return (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + } + + /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ + private fun readString(input: InputStream): String = + // Read 8-byte little-endian length (number of bytes in the string). + readLittleLong(input).let { len -> + if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") + + // Read the UTF-8 bytes of the given length. 
+ ByteArray(len.toInt()).let { + if (it.isNotEmpty()) input.readFully(it) + String(it, Charsets.UTF_8) + } + } + + /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ + private fun littleEndianBytesToInt(bytes: ByteArray): Int = + // Note: assumes bytes length is 4. + (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + + /** + * Robust skip that works the same on JDK 11 and Android’s desugared runtime. + * + * @param n Number of bytes to advance in the stream. + * @throws IOException on premature EOF. + */ + private fun InputStream.skipFully(n: Long) { + var remaining = n + val scratch = ByteArray(8192) // read‑and‑toss buffer + while (remaining > 0) { + val skipped = skip(remaining) + when { + skipped > 0 -> remaining -= skipped // normal fast path + skipped == 0L -> { + // fallback: read and discard + val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt()) + if (read == -1) throw IOException("EOF while skipping $n bytes") + remaining -= read + } + else -> throw IOException("Skip returned negative value") + } + } + } + + /** + * Extension that keeps reading until the requested number of bytes are filled. + * Falls back to `read()` when `skip()` returns 0, which happens on some Android + * streams. + * + * @param buf Destination buffer. + * @param len Number of bytes to fill (defaults to `buf.size`). + * @throws IOException on premature EOF. + */ + private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) { + var off = 0 + while (off < len) { + val n = read(buf, off, len - off) + if (n == -1) throw IOException("EOF after $off of $len bytes") + off += n + } + } + + /** + * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array. + * This is used for small fixed‑length reads (e.g. 4‑byte type codes). + * + * @throws IOException on premature EOF. 
+ */ + private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also { + if (read(it) != n) throw IOException("Unexpected EOF") + } +} |
