diff options
| author | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
|---|---|---|
| committer | Mitja Felicijan <mitja.felicijan@gmail.com> | 2026-02-12 20:57:17 +0100 |
| commit | b333b06772c89d96aacb5490d6a219fba7c09cc6 (patch) | |
| tree | 211df60083a5946baa2ed61d33d8121b7e251b06 /llama.cpp/examples/llama.android | |
| download | llmnpc-b333b06772c89d96aacb5490d6a219fba7c09cc6.tar.gz | |
Engage!
Diffstat (limited to 'llama.cpp/examples/llama.android')
57 files changed, 3293 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/.gitignore b/llama.cpp/examples/llama.android/.gitignore new file mode 100644 index 0000000..347e252 --- /dev/null +++ b/llama.cpp/examples/llama.android/.gitignore @@ -0,0 +1,33 @@ +# Gradle files +.gradle/ +build/ + +# Local configuration file (sdk path, etc) +local.properties + +# Log/OS Files +*.log + +# Android Studio generated files and folders +captures/ +.externalNativeBuild/ +.cxx/ +*.apk +output.json + +# IntelliJ +*.iml +.idea/ +misc.xml +deploymentTargetDropDown.xml +render.experimental.xml + +# Keystore files +*.jks +*.keystore + +# Google Services (e.g. APIs or Firebase) +google-services.json + +# Android Profiling +*.hprof diff --git a/llama.cpp/examples/llama.android/app/.gitignore b/llama.cpp/examples/llama.android/app/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/.gitignore @@ -0,0 +1 @@ +/build diff --git a/llama.cpp/examples/llama.android/app/build.gradle.kts b/llama.cpp/examples/llama.android/app/build.gradle.kts new file mode 100644 index 0000000..2edfe98 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/build.gradle.kts @@ -0,0 +1,58 @@ +plugins { + alias(libs.plugins.android.application) + alias(libs.plugins.jetbrains.kotlin.android) +} + +android { + namespace = "com.example.llama" + compileSdk = 36 + + defaultConfig { + applicationId = "com.example.llama.aichat" + + minSdk = 33 + targetSdk = 36 + + versionCode = 1 + versionName = "1.0" + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + vectorDrawables { + useSupportLibrary = true + } + } + + buildTypes { + debug { + isMinifyEnabled = true + isShrinkResources = true + proguardFiles( + getDefaultProguardFile("proguard-android.txt"), + "proguard-rules.pro" + ) + } + release { + isMinifyEnabled = true + isShrinkResources = true + proguardFiles( + getDefaultProguardFile("proguard-android-optimize.txt"), + "proguard-rules.pro" + ) + } + } + 
compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } +} + +dependencies { + implementation(libs.bundles.androidx) + implementation(libs.material) + + implementation(project(":lib")) + + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) + androidTestImplementation(libs.androidx.espresso.core) +} diff --git a/llama.cpp/examples/llama.android/app/proguard-rules.pro b/llama.cpp/examples/llama.android/app/proguard-rules.pro new file mode 100644 index 0000000..358020d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/proguard-rules.pro @@ -0,0 +1,29 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. +#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. 
+#-renamesourcefileattribute SourceFile + +-keep class com.arm.aichat.* { *; } +-keep class com.arm.aichat.gguf.* { *; } + +-assumenosideeffects class android.util.Log { + public static int v(...); + public static int d(...); +} diff --git a/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml b/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000..8f7c606 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/AndroidManifest.xml @@ -0,0 +1,27 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android"> + + <application + android:allowBackup="true" + android:dataExtractionRules="@xml/data_extraction_rules" + android:extractNativeLibs="true" + android:fullBackupContent="@xml/backup_rules" + android:icon="@mipmap/ic_launcher_round" + android:label="@string/app_name" + android:roundIcon="@mipmap/ic_launcher_round" + android:supportsRtl="true" + android:theme="@style/Theme.AiChatSample" + > + + <activity + android:name=".MainActivity" + android:exported="true"> + <intent-filter> + <action android:name="android.intent.action.MAIN" /> + + <category android:name="android.intent.category.LAUNCHER" /> + </intent-filter> + </activity> + </application> + +</manifest> diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt new file mode 100644 index 0000000..872ec2b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt @@ -0,0 +1,275 @@ +package com.example.llama + +import android.net.Uri +import android.os.Bundle +import android.util.Log +import android.widget.EditText +import android.widget.TextView +import android.widget.Toast +import androidx.activity.addCallback +import androidx.activity.enableEdgeToEdge +import androidx.activity.result.contract.ActivityResultContracts 
+import androidx.appcompat.app.AppCompatActivity +import androidx.lifecycle.lifecycleScope +import androidx.recyclerview.widget.LinearLayoutManager +import androidx.recyclerview.widget.RecyclerView +import com.arm.aichat.AiChat +import com.arm.aichat.InferenceEngine +import com.arm.aichat.gguf.GgufMetadata +import com.arm.aichat.gguf.GgufMetadataReader +import com.google.android.material.floatingactionbutton.FloatingActionButton +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.flow.onCompletion +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext +import java.io.File +import java.io.FileOutputStream +import java.io.InputStream +import java.util.UUID + +class MainActivity : AppCompatActivity() { + + // Android views + private lateinit var ggufTv: TextView + private lateinit var messagesRv: RecyclerView + private lateinit var userInputEt: EditText + private lateinit var userActionFab: FloatingActionButton + + // Arm AI Chat inference engine + private lateinit var engine: InferenceEngine + private var generationJob: Job? = null + + // Conversation states + private var isModelReady = false + private val messages = mutableListOf<Message>() + private val lastAssistantMsg = StringBuilder() + private val messageAdapter = MessageAdapter(messages) + + override fun onCreate(savedInstanceState: Bundle?) 
{ + super.onCreate(savedInstanceState) + enableEdgeToEdge() + setContentView(R.layout.activity_main) + // View model boilerplate and state management is out of this basic sample's scope + onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") } + + // Find views + ggufTv = findViewById(R.id.gguf) + messagesRv = findViewById(R.id.messages) + messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true } + messagesRv.adapter = messageAdapter + userInputEt = findViewById(R.id.user_input) + userActionFab = findViewById(R.id.fab) + + // Arm AI Chat initialization + lifecycleScope.launch(Dispatchers.Default) { + engine = AiChat.getInferenceEngine(applicationContext) + } + + // Upon CTA button tapped + userActionFab.setOnClickListener { + if (isModelReady) { + // If model is ready, validate input and send to engine + handleUserInput() + } else { + // Otherwise, prompt user to select a GGUF metadata on the device + getContent.launch(arrayOf("*/*")) + } + } + } + + private val getContent = registerForActivityResult( + ActivityResultContracts.OpenDocument() + ) { uri -> + Log.i(TAG, "Selected file uri:\n $uri") + uri?.let { handleSelectedModel(it) } + } + + /** + * Handles the file Uri from [getContent] result + */ + private fun handleSelectedModel(uri: Uri) { + // Update UI states + userActionFab.isEnabled = false + userInputEt.hint = "Parsing GGUF..." 
+ ggufTv.text = "Parsing metadata from selected file \n$uri" + + lifecycleScope.launch(Dispatchers.IO) { + // Parse GGUF metadata + Log.i(TAG, "Parsing GGUF metadata...") + contentResolver.openInputStream(uri)?.use { + GgufMetadataReader.create().readStructuredMetadata(it) + }?.let { metadata -> + // Update UI to show GGUF metadata to user + Log.i(TAG, "GGUF parsed: \n$metadata") + withContext(Dispatchers.Main) { + ggufTv.text = metadata.toString() + } + + // Ensure the model file is available + val modelName = metadata.filename() + FILE_EXTENSION_GGUF + contentResolver.openInputStream(uri)?.use { input -> + ensureModelFile(modelName, input) + }?.let { modelFile -> + loadModel(modelName, modelFile) + + withContext(Dispatchers.Main) { + isModelReady = true + userInputEt.hint = "Type and send a message!" + userInputEt.isEnabled = true + userActionFab.setImageResource(R.drawable.outline_send_24) + userActionFab.isEnabled = true + } + } + } + } + } + + /** + * Prepare the model file within app's private storage + */ + private suspend fun ensureModelFile(modelName: String, input: InputStream) = + withContext(Dispatchers.IO) { + File(ensureModelsDirectory(), modelName).also { file -> + // Copy the file into local storage if not yet done + if (!file.exists()) { + Log.i(TAG, "Start copying file to $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Copying file..." + } + + FileOutputStream(file).use { input.copyTo(it) } + Log.i(TAG, "Finished copying file to $modelName") + } else { + Log.i(TAG, "File already exists $modelName") + } + } + } + + /** + * Load the model file from the app private storage + */ + private suspend fun loadModel(modelName: String, modelFile: File) = + withContext(Dispatchers.IO) { + Log.i(TAG, "Loading model $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Loading model..." 
+ } + engine.loadModel(modelFile.path) + } + + /** + * Validate and send the user message into [InferenceEngine] + */ + private fun handleUserInput() { + userInputEt.text.toString().also { userMsg -> + if (userMsg.isEmpty()) { + Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show() + } else { + userInputEt.text = null + userInputEt.isEnabled = false + userActionFab.isEnabled = false + + // Update message states + messages.add(Message(UUID.randomUUID().toString(), userMsg, true)) + lastAssistantMsg.clear() + messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false)) + + generationJob = lifecycleScope.launch(Dispatchers.Default) { + engine.sendUserPrompt(userMsg) + .onCompletion { + withContext(Dispatchers.Main) { + userInputEt.isEnabled = true + userActionFab.isEnabled = true + } + }.collect { token -> + withContext(Dispatchers.Main) { + val messageCount = messages.size + check(messageCount > 0 && !messages[messageCount - 1].isUser) + + messages.removeAt(messageCount - 1).copy( + content = lastAssistantMsg.append(token).toString() + ).let { messages.add(it) } + + messageAdapter.notifyItemChanged(messages.size - 1) + } + } + } + } + } + } + + /** + * Run a benchmark with the model file + */ + @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers") + private suspend fun runBenchmark(modelName: String, modelFile: File) = + withContext(Dispatchers.Default) { + Log.i(TAG, "Starts benchmarking $modelName") + withContext(Dispatchers.Main) { + userInputEt.hint = "Running benchmark..." + } + engine.bench( + pp=BENCH_PROMPT_PROCESSING_TOKENS, + tg=BENCH_TOKEN_GENERATION_TOKENS, + pl=BENCH_SEQUENCE, + nr=BENCH_REPETITION + ).let { result -> + messages.add(Message(UUID.randomUUID().toString(), result, false)) + withContext(Dispatchers.Main) { + messageAdapter.notifyItemChanged(messages.size - 1) + } + } + } + + /** + * Create the `models` directory if not exist. 
+ */ + private fun ensureModelsDirectory() = + File(filesDir, DIRECTORY_MODELS).also { + if (it.exists() && !it.isDirectory) { it.delete() } + if (!it.exists()) { it.mkdir() } + } + + override fun onStop() { + generationJob?.cancel() + super.onStop() + } + + override fun onDestroy() { + engine.destroy() + super.onDestroy() + } + + companion object { + private val TAG = MainActivity::class.java.simpleName + + private const val DIRECTORY_MODELS = "models" + private const val FILE_EXTENSION_GGUF = ".gguf" + + private const val BENCH_PROMPT_PROCESSING_TOKENS = 512 + private const val BENCH_TOKEN_GENERATION_TOKENS = 128 + private const val BENCH_SEQUENCE = 1 + private const val BENCH_REPETITION = 3 + } +} + +fun GgufMetadata.filename() = when { + basic.name != null -> { + basic.name?.let { name -> + basic.sizeLabel?.let { size -> + "$name-$size" + } ?: name + } + } + architecture?.architecture != null -> { + architecture?.architecture?.let { arch -> + basic.uuid?.let { uuid -> + "$arch-$uuid" + } ?: "$arch-${System.currentTimeMillis()}" + } + } + else -> { + "model-${System.currentTimeMillis().toHexString()}" + } +} diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt new file mode 100644 index 0000000..0439f96 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt @@ -0,0 +1,51 @@ +package com.example.llama + +import android.view.LayoutInflater +import android.view.View +import android.view.ViewGroup +import android.widget.TextView +import androidx.recyclerview.widget.RecyclerView + +data class Message( + val id: String, + val content: String, + val isUser: Boolean +) + +class MessageAdapter( + private val messages: List<Message> +) : RecyclerView.Adapter<RecyclerView.ViewHolder>() { + + companion object { + private const val VIEW_TYPE_USER = 1 + private const val VIEW_TYPE_ASSISTANT 
= 2 + } + + override fun getItemViewType(position: Int): Int { + return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT + } + + override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder { + val layoutInflater = LayoutInflater.from(parent.context) + return if (viewType == VIEW_TYPE_USER) { + val view = layoutInflater.inflate(R.layout.item_message_user, parent, false) + UserMessageViewHolder(view) + } else { + val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false) + AssistantMessageViewHolder(view) + } + } + + override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) { + val message = messages[position] + if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) { + val textView = holder.itemView.findViewById<TextView>(R.id.msg_content) + textView.text = message.content + } + } + + override fun getItemCount(): Int = messages.size + + class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) + class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view) +} diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml new file mode 100644 index 0000000..f90c3db --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_assistant_message.xml @@ -0,0 +1,4 @@ +<shape xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle"> + <solid android:color="#E5E5EA" /> + <corners android:radius="16dp" /> +</shape> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml new file mode 100644 index 0000000..3ca7dae --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/bg_user_message.xml @@ -0,0 +1,4 @@ +<shape 
xmlns:android="http://schemas.android.com/apk/res/android" android:shape="rectangle"> + <solid android:color="#4285F4" /> + <corners android:radius="16dp" /> +</shape> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml new file mode 100644 index 0000000..07d5da9 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_background.xml @@ -0,0 +1,170 @@ +<?xml version="1.0" encoding="utf-8"?> +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="108dp" + android:height="108dp" + android:viewportWidth="108" + android:viewportHeight="108"> + <path + android:fillColor="#3DDC84" + android:pathData="M0,0h108v108h-108z" /> + <path + android:fillColor="#00000000" + android:pathData="M9,0L9,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,0L19,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M29,0L29,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M39,0L39,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M49,0L49,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M59,0L59,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M69,0L69,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M79,0L79,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + 
android:pathData="M89,0L89,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M99,0L99,108" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,9L108,9" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,19L108,19" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,29L108,29" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,39L108,39" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,49L108,49" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,59L108,59" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,69L108,69" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,79L108,79" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,89L108,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M0,99L108,99" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,29L89,29" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,39L89,39" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,49L89,49" + 
android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,59L89,59" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,69L89,69" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M19,79L89,79" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M29,19L29,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M39,19L39,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M49,19L49,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M59,19L59,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M69,19L69,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> + <path + android:fillColor="#00000000" + android:pathData="M79,19L79,89" + android:strokeWidth="0.8" + android:strokeColor="#33FFFFFF" /> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml new file mode 100644 index 0000000..7706ab9 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/ic_launcher_foreground.xml @@ -0,0 +1,30 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:aapt="http://schemas.android.com/aapt" + android:width="108dp" + android:height="108dp" + android:viewportWidth="108" + android:viewportHeight="108"> + <path android:pathData="M31,63.928c0,0 6.4,-11 12.1,-13.1c7.2,-2.6 26,-1.4 
26,-1.4l38.1,38.1L107,108.928l-32,-1L31,63.928z"> + <aapt:attr name="android:fillColor"> + <gradient + android:endX="85.84757" + android:endY="92.4963" + android:startX="42.9492" + android:startY="49.59793" + android:type="linear"> + <item + android:color="#44000000" + android:offset="0.0" /> + <item + android:color="#00000000" + android:offset="1.0" /> + </gradient> + </aapt:attr> + </path> + <path + android:fillColor="#FFFFFF" + android:fillType="nonZero" + android:pathData="M65.3,45.828l3.8,-6.6c0.2,-0.4 0.1,-0.9 -0.3,-1.1c-0.4,-0.2 -0.9,-0.1 -1.1,0.3l-3.9,6.7c-6.3,-2.8 -13.4,-2.8 -19.7,0l-3.9,-6.7c-0.2,-0.4 -0.7,-0.5 -1.1,-0.3C38.8,38.328 38.7,38.828 38.9,39.228l3.8,6.6C36.2,49.428 31.7,56.028 31,63.928h46C76.3,56.028 71.8,49.428 65.3,45.828zM43.4,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2c-0.3,-0.7 -0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C45.3,56.528 44.5,57.328 43.4,57.328L43.4,57.328zM64.6,57.328c-0.8,0 -1.5,-0.5 -1.8,-1.2s-0.1,-1.5 0.4,-2.1c0.5,-0.5 1.4,-0.7 2.1,-0.4c0.7,0.3 1.2,1 1.2,1.8C66.5,56.528 65.6,57.328 64.6,57.328L64.6,57.328z" + android:strokeWidth="1" + android:strokeColor="#00000000" /> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml new file mode 100644 index 0000000..f58b501 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_folder_open_24.xml @@ -0,0 +1,10 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="24dp" + android:height="24dp" + android:viewportWidth="24" + android:viewportHeight="24" + android:tint="?attr/colorControlNormal"> + <path + android:fillColor="@android:color/white" + android:pathData="M20,6h-8l-2,-2L4,4c-1.1,0 -1.99,0.9 -1.99,2L2,18c0,1.1 0.9,2 2,2h16c1.1,0 2,-0.9 2,-2L22,8c0,-1.1 -0.9,-2 -2,-2zM20,18L4,18L4,8h16v10z"/> +</vector> diff --git 
a/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml new file mode 100644 index 0000000..712adc0 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/drawable/outline_send_24.xml @@ -0,0 +1,11 @@ +<vector xmlns:android="http://schemas.android.com/apk/res/android" + android:width="24dp" + android:height="24dp" + android:viewportWidth="24" + android:viewportHeight="24" + android:tint="?attr/colorControlNormal" + android:autoMirrored="true"> + <path + android:fillColor="@android:color/white" + android:pathData="M4.01,6.03l7.51,3.22 -7.52,-1 0.01,-2.22m7.5,8.72L4,17.97v-2.22l7.51,-1M2.01,3L2,10l15,2 -15,2 0.01,7L23,12 2.01,3z"/> +</vector> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml new file mode 100644 index 0000000..d15772b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/activity_main.xml @@ -0,0 +1,77 @@ +<?xml version="1.0" encoding="utf-8"?> +<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android" + xmlns:app="http://schemas.android.com/apk/res-auto" + xmlns:tools="http://schemas.android.com/tools" + android:id="@+id/main" + android:layout_height="match_parent" + android:layout_width="match_parent"> + + <LinearLayout + android:fitsSystemWindows="true" + android:layout_width="match_parent" + android:layout_height="match_parent" + android:orientation="vertical" + android:layout_marginEnd="4dp" + tools:context=".MainActivity"> + + <ScrollView + android:layout_width="match_parent" + android:layout_height="0dp" + android:layout_weight="1" + android:fadeScrollbars="false"> + + <TextView + android:id="@+id/gguf" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:padding="16dp" + android:text="Selected GGUF 
model's metadata will show here." + style="@style/TextAppearance.MaterialComponents.Body2" /> + + </ScrollView> + + <com.google.android.material.divider.MaterialDivider + android:layout_width="match_parent" + android:layout_height="2dp" + android:layout_marginHorizontal="16dp" /> + + <androidx.recyclerview.widget.RecyclerView + android:id="@+id/messages" + android:layout_width="match_parent" + android:layout_height="0dp" + android:layout_weight="4" + android:fadeScrollbars="false" + android:scrollbars="vertical" + app:reverseLayout="true" + tools:listitem="@layout/item_message_assistant"/> + + <LinearLayout + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:orientation="horizontal" + android:paddingStart="16dp" + android:paddingEnd="4dp"> + + <EditText + android:id="@+id/user_input" + android:enabled="false" + android:layout_width="0dp" + android:layout_weight="1" + android:layout_height="match_parent" + android:padding="8dp" + style="@style/TextAppearance.MaterialComponents.Body2" + android:hint="Please first pick a GGUF model file to import." 
/> + + <com.google.android.material.floatingactionbutton.FloatingActionButton + android:id="@+id/fab" + android:enabled="true" + style="@style/Widget.Material3.FloatingActionButton.Primary" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:layout_margin="12dp" + android:src="@drawable/outline_folder_open_24" /> + + </LinearLayout> + + </LinearLayout> +</androidx.constraintlayout.widget.ConstraintLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml new file mode 100644 index 0000000..2c8e4bc --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_assistant.xml @@ -0,0 +1,16 @@ +<?xml version="1.0" encoding="utf-8"?> +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:layout_marginHorizontal="16dp" + android:layout_marginVertical="8dp" + android:gravity="start"> + + <TextView + android:id="@+id/msg_content" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:background="@drawable/bg_assistant_message" + android:padding="12dp" + android:textColor="@android:color/black" /> +</LinearLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml new file mode 100644 index 0000000..5aa79f2 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/layout/item_message_user.xml @@ -0,0 +1,16 @@ +<?xml version="1.0" encoding="utf-8"?> +<LinearLayout xmlns:android="http://schemas.android.com/apk/res/android" + android:layout_width="match_parent" + android:layout_height="wrap_content" + android:layout_marginHorizontal="16dp" + android:layout_marginVertical="8dp" + android:gravity="end"> + + <TextView + 
android:id="@+id/msg_content" + android:layout_width="wrap_content" + android:layout_height="wrap_content" + android:background="@drawable/bg_user_message" + android:padding="12dp" + android:textColor="@android:color/white" /> +</LinearLayout> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml new file mode 100644 index 0000000..b3e26b4 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android"> + <background android:drawable="@drawable/ic_launcher_background" /> + <foreground android:drawable="@drawable/ic_launcher_foreground" /> + <monochrome android:drawable="@drawable/ic_launcher_foreground" /> +</adaptive-icon> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml new file mode 100644 index 0000000..b3e26b4 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-anydpi/ic_launcher_round.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="utf-8"?> +<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android"> + <background android:drawable="@drawable/ic_launcher_background" /> + <foreground android:drawable="@drawable/ic_launcher_foreground" /> + <monochrome android:drawable="@drawable/ic_launcher_foreground" /> +</adaptive-icon> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..c209e78 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher.webp diff --git 
a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..b2dfe3d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-hdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..4f0f1d6 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..62b611d --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-mdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..948a307 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..1b9a695 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..28d4b77 --- /dev/null +++ 
b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..9287f50 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp Binary files differnew file mode 100644 index 0000000..aa7d642 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp Binary files differnew file mode 100644 index 0000000..9126ae3 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/mipmap-xxxhdpi/ic_launcher_round.webp diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml new file mode 100644 index 0000000..ca1931b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/colors.xml @@ -0,0 +1,10 @@ +<?xml version="1.0" encoding="utf-8"?> +<resources> + <color name="purple_200">#FFBB86FC</color> + <color name="purple_500">#FF6200EE</color> + <color name="purple_700">#FF3700B3</color> + <color name="teal_200">#FF03DAC5</color> + <color name="teal_700">#FF018786</color> + <color name="black">#FF000000</color> + <color name="white">#FFFFFFFF</color> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml new file mode 100644 index 
0000000..36059fc --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/strings.xml @@ -0,0 +1,3 @@ +<resources> + <string name="app_name">AI Chat basic sample</string> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml b/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml new file mode 100644 index 0000000..2e4fdad --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/values/themes.xml @@ -0,0 +1,10 @@ +<?xml version="1.0" encoding="utf-8"?> +<resources> + + <style name="Base.Theme.AiChatSample" parent="Theme.Material3.DayNight.NoActionBar"> + <!-- Customize your light theme here. --> + <!-- <item name="colorPrimary">@color/my_light_primary</item> --> + </style> + + <style name="Theme.AiChatSample" parent="Base.Theme.AiChatSample" /> +</resources> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml b/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml new file mode 100644 index 0000000..148c18b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/xml/backup_rules.xml @@ -0,0 +1,13 @@ +<?xml version="1.0" encoding="utf-8"?><!-- + Sample backup rules file; uncomment and customize as necessary. + See https://developer.android.com/guide/topics/data/autobackup + for details. 
+ Note: This file is ignored for devices older that API 31 + See https://developer.android.com/about/versions/12/backup-restore +--> +<full-backup-content> + <!-- + <include domain="sharedpref" path="."/> + <exclude domain="sharedpref" path="device.xml"/> +--> +</full-backup-content> diff --git a/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml b/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml new file mode 100644 index 0000000..0c4f95c --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/res/xml/data_extraction_rules.xml @@ -0,0 +1,19 @@ +<?xml version="1.0" encoding="utf-8"?><!-- + Sample data extraction rules file; uncomment and customize as necessary. + See https://developer.android.com/about/versions/12/backup-restore#xml-changes + for details. +--> +<data-extraction-rules> + <cloud-backup> + <!-- TODO: Use <include> and <exclude> to control what is backed up. + <include .../> + <exclude .../> + --> + </cloud-backup> + <!-- + <device-transfer> + <include .../> + <exclude .../> + </device-transfer> + --> +</data-extraction-rules> diff --git a/llama.cpp/examples/llama.android/build.gradle.kts b/llama.cpp/examples/llama.android/build.gradle.kts new file mode 100644 index 0000000..076a0f1 --- /dev/null +++ b/llama.cpp/examples/llama.android/build.gradle.kts @@ -0,0 +1,6 @@ +// Top-level build file where you can add configuration options common to all sub-projects/modules. +plugins { + alias(libs.plugins.android.application) apply false + alias(libs.plugins.android.library) apply false + alias(libs.plugins.jetbrains.kotlin.android) apply false +} diff --git a/llama.cpp/examples/llama.android/gradle.properties b/llama.cpp/examples/llama.android/gradle.properties new file mode 100644 index 0000000..8888cc9 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle.properties @@ -0,0 +1,24 @@ +# Project-wide Gradle settings. +# IDE (e.g. 
Android Studio) users: +# Gradle settings configured through the IDE *will override* +# any settings specified in this file. +# For more details on how to configure your build environment visit +# http://www.gradle.org/docs/current/userguide/build_environment.html +# Specifies the JVM arguments used for the daemon process. +# The setting is particularly useful for tweaking memory settings. +org.gradle.jvmargs=-Xmx2048m -Dfile.encoding=UTF-8 +# When configured, Gradle will run in incubating parallel mode. +# This option should only be used with decoupled projects. More details, visit +# http://www.gradle.org/docs/current/userguide/multi_project_builds.html#sec:decoupled_projects +# org.gradle.parallel=true +# AndroidX package structure to make it clearer which packages are bundled with the +# Android operating system, and which are packaged with your app's APK +# https://developer.android.com/topic/libraries/support-library/androidx-rn +android.useAndroidX=true +# Kotlin code style for this project: "official" or "obsolete": +kotlin.code.style=official +# Enables namespacing of each library's R class so that its R class includes only the +# resources declared in the library itself and none from the library's dependencies, +# thereby reducing the size of the R class for that library +android.nonTransitiveRClass=true +android.native.buildOutput=verbose diff --git a/llama.cpp/examples/llama.android/gradle/libs.versions.toml b/llama.cpp/examples/llama.android/gradle/libs.versions.toml new file mode 100644 index 0000000..8ff2afd --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/libs.versions.toml @@ -0,0 +1,53 @@ +[versions] + +# Plugins +agp = "8.13.2" +kotlin = "2.3.0" + +# AndroidX +activity = "1.12.2" +appcompat = "1.7.1" +core-ktx = "1.17.0" +constraint-layout = "2.2.1" +datastore-preferences = "1.2.0" + +# Material +material = "1.13.0" + +# Testing +espresso-core = "3.7.0" +androidx-junit = "1.3.0" +junit = "4.13.2" + + +[plugins] +android-application = { 
id = "com.android.application", version.ref = "agp" } +android-library = { id = "com.android.library", version.ref = "agp" } +jetbrains-kotlin-android = { id = "org.jetbrains.kotlin.android", version.ref = "kotlin" } + + +[libraries] + +# AndroidX +androidx-activity = { group = "androidx.activity", name = "activity", version.ref = "activity" } +androidx-appcompat = { group = "androidx.appcompat", name = "appcompat", version.ref = "appcompat" } +androidx-constraintlayout = { group = "androidx.constraintlayout", name = "constraintlayout", version.ref = "constraint-layout" } +androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "core-ktx" } +androidx-datastore-preferences = { group = "androidx.datastore", name = "datastore-preferences", version.ref = "datastore-preferences" } + +#Material +material = { group = "com.google.android.material", name = "material", version.ref = "material" } + +# Testing +androidx-espresso-core = { group = "androidx.test.espresso", name = "espresso-core", version.ref = "espresso-core" } +androidx-junit = { group = "androidx.test.ext", name = "junit", version.ref = "androidx-junit" } +junit = { group = "junit", name = "junit", version.ref = "junit" } + +[bundles] +androidx = [ + "androidx-activity", + "androidx-appcompat", + "androidx-constraintlayout", + "androidx-core-ktx", + "androidx-datastore-preferences", +] diff --git a/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar Binary files differnew file mode 100644 index 0000000..e708b1c --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.jar diff --git a/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..6b993e9 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 
@@ +#Tue Apr 01 11:15:06 PDT 2025 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14.3-bin.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/llama.cpp/examples/llama.android/gradlew b/llama.cpp/examples/llama.android/gradlew new file mode 100755 index 0000000..4f906e0 --- /dev/null +++ b/llama.cpp/examples/llama.android/gradlew @@ -0,0 +1,185 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. 
+MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" 
"$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/llama.cpp/examples/llama.android/lib/.gitignore b/llama.cpp/examples/llama.android/lib/.gitignore new file mode 100644 index 0000000..796b96d --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/.gitignore @@ -0,0 +1 @@ +/build diff --git a/llama.cpp/examples/llama.android/lib/build.gradle.kts b/llama.cpp/examples/llama.android/lib/build.gradle.kts new file mode 100644 index 0000000..9b290d6 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/build.gradle.kts @@ -0,0 +1,78 @@ +plugins { + alias(libs.plugins.android.library) + alias(libs.plugins.jetbrains.kotlin.android) +} + +android { + namespace = "com.arm.aichat" + compileSdk = 36 + + ndkVersion = "29.0.13113456" + + defaultConfig { + minSdk = 33 + + testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner" + consumerProguardFiles("consumer-rules.pro") + + ndk { + abiFilters += listOf("arm64-v8a", "x86_64") + } + externalNativeBuild { + cmake { + arguments += "-DCMAKE_BUILD_TYPE=Release" + arguments += "-DCMAKE_MESSAGE_LOG_LEVEL=DEBUG" + arguments += "-DCMAKE_VERBOSE_MAKEFILE=ON" + + arguments += "-DBUILD_SHARED_LIBS=ON" + arguments += "-DLLAMA_BUILD_COMMON=ON" + arguments += "-DLLAMA_OPENSSL=OFF" + + arguments += "-DGGML_NATIVE=OFF" + arguments += "-DGGML_BACKEND_DL=ON" + arguments += "-DGGML_CPU_ALL_VARIANTS=ON" + arguments += "-DGGML_LLAMAFILE=OFF" + } + } + aarMetadata { + minCompileSdk = 35 + } + } + externalNativeBuild { + cmake { + path("src/main/cpp/CMakeLists.txt") 
+ version = "3.31.6" + } + } + compileOptions { + sourceCompatibility = JavaVersion.VERSION_17 + targetCompatibility = JavaVersion.VERSION_17 + } + kotlin { + jvmToolchain(17) + + compileOptions { + targetCompatibility = JavaVersion.VERSION_17 + } + } + + packaging { + resources { + excludes += "/META-INF/{AL2.0,LGPL2.1}" + } + } + + publishing { + singleVariant("release") { + withJavadocJar() + } + } +} + +dependencies { + implementation(libs.androidx.core.ktx) + implementation(libs.androidx.datastore.preferences) + + testImplementation(libs.junit) + androidTestImplementation(libs.androidx.junit) +} diff --git a/llama.cpp/examples/llama.android/lib/consumer-rules.pro b/llama.cpp/examples/llama.android/lib/consumer-rules.pro new file mode 100644 index 0000000..e6eb6f5 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/consumer-rules.pro @@ -0,0 +1,8 @@ +-keep class com.arm.aichat.* { *; } +-keep class com.arm.aichat.gguf.* { *; } + +-keepclasseswithmembernames class * { + native <methods>; +} + +-keep class kotlin.Metadata { *; } diff --git a/llama.cpp/examples/llama.android/lib/proguard-rules.pro b/llama.cpp/examples/llama.android/lib/proguard-rules.pro new file mode 100644 index 0000000..f1b4245 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/proguard-rules.pro @@ -0,0 +1,21 @@ +# Add project specific ProGuard rules here. +# You can control the set of applied configuration files using the +# proguardFiles setting in build.gradle. +# +# For more details, see +# http://developer.android.com/guide/developing/tools/proguard.html + +# If your project uses WebView with JS, uncomment the following +# and specify the fully qualified class name to the JavaScript interface +# class: +#-keepclassmembers class fqcn.of.javascript.interface.for.webview { +# public *; +#} + +# Uncomment this to preserve the line number information for +# debugging stack traces. 
+#-keepattributes SourceFile,LineNumberTable + +# If you keep the line number information, uncomment this to +# hide the original source file name. +#-renamesourcefileattribute SourceFile diff --git a/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt b/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt new file mode 100644 index 0000000..05d6ab5 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/androidTest/java/android/llama/cpp/ExampleInstrumentedTest.kt @@ -0,0 +1,24 @@ +package android.llama.cpp + +import androidx.test.platform.app.InstrumentationRegistry +import androidx.test.ext.junit.runners.AndroidJUnit4 + +import org.junit.Test +import org.junit.runner.RunWith + +import org.junit.Assert.* + +/** + * Instrumented test, which will execute on an Android device. + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +@RunWith(AndroidJUnit4::class) +class ExampleInstrumentedTest { + @Test + fun useAppContext() { + // Context of the app under test. 
+ val appContext = InstrumentationRegistry.getInstrumentation().targetContext + assertEquals("android.llama.cpp.test", appContext.packageName) + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml b/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml new file mode 100644 index 0000000..8bdb7e1 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/AndroidManifest.xml @@ -0,0 +1,4 @@ +<?xml version="1.0" encoding="utf-8"?> +<manifest xmlns:android="http://schemas.android.com/apk/res/android"> + +</manifest> diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt b/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt new file mode 100644 index 0000000..7862c61 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/CMakeLists.txt @@ -0,0 +1,56 @@ +cmake_minimum_required(VERSION 3.31.6) + +project("ai-chat" VERSION 1.0.0 LANGUAGES C CXX) + +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED true) + +set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}" CACHE STRING "" FORCE) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" CACHE STRING "" FORCE) + +# -------------------------------------------------------------------------- +# AI Chat library +# -------------------------------------------------------------------------- + +if(DEFINED ANDROID_ABI) + message(STATUS "Detected Android ABI: ${ANDROID_ABI}") + if(ANDROID_ABI STREQUAL "arm64-v8a") + set(GGML_SYSTEM_ARCH "ARM") + set(GGML_CPU_KLEIDIAI ON) + set(GGML_OPENMP ON) + elseif(ANDROID_ABI STREQUAL "x86_64") + set(GGML_SYSTEM_ARCH "x86") + set(GGML_CPU_KLEIDIAI OFF) + set(GGML_OPENMP OFF) + else() + message(FATAL_ERROR "Unsupported ABI: ${ANDROID_ABI}") + endif() +endif() + +set(LLAMA_SRC ${CMAKE_CURRENT_LIST_DIR}/../../../../../../) +add_subdirectory(${LLAMA_SRC} build-llama) + +add_library(${CMAKE_PROJECT_NAME} SHARED + ai_chat.cpp) + 
+target_compile_definitions(${CMAKE_PROJECT_NAME} PRIVATE + GGML_SYSTEM_ARCH=${GGML_SYSTEM_ARCH} + GGML_CPU_KLEIDIAI=$<BOOL:${GGML_CPU_KLEIDIAI}> + GGML_OPENMP=$<BOOL:${GGML_OPENMP}> +) + +target_include_directories(${CMAKE_PROJECT_NAME} PRIVATE + ${LLAMA_SRC} + ${LLAMA_SRC}/common + ${LLAMA_SRC}/include + ${LLAMA_SRC}/ggml/include + ${LLAMA_SRC}/ggml/src) + +target_link_libraries(${CMAKE_PROJECT_NAME} + llama + common + android + log) diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp b/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp new file mode 100644 index 0000000..9e460ac --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/ai_chat.cpp @@ -0,0 +1,565 @@ +#include <android/log.h> +#include <jni.h> +#include <iomanip> +#include <cmath> +#include <string> +#include <unistd.h> +#include <sampling.h> + +#include "logging.h" +#include "chat.h" +#include "common.h" +#include "llama.h" + +template<class T> +static std::string join(const std::vector<T> &values, const std::string &delim) { + std::ostringstream str; + for (size_t i = 0; i < values.size(); i++) { + str << values[i]; + if (i < values.size() - 1) { str << delim; } + } + return str.str(); +} + +/** + * LLama resources: context, model, batch and sampler + */ +constexpr int N_THREADS_MIN = 2; +constexpr int N_THREADS_MAX = 4; +constexpr int N_THREADS_HEADROOM = 2; + +constexpr int DEFAULT_CONTEXT_SIZE = 8192; +constexpr int OVERFLOW_HEADROOM = 4; +constexpr int BATCH_SIZE = 512; +constexpr float DEFAULT_SAMPLER_TEMP = 0.3f; + +static llama_model * g_model; +static llama_context * g_context; +static llama_batch g_batch; +static common_chat_templates_ptr g_chat_templates; +static common_sampler * g_sampler; + +extern "C" +JNIEXPORT void JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_init(JNIEnv *env, jobject /*unused*/, jstring nativeLibDir) { + // Set llama log handler to Android + llama_log_set(aichat_android_log_callback, nullptr); + + // 
Loading all CPU backend variants + const auto *path_to_backend = env->GetStringUTFChars(nativeLibDir, 0); + LOGi("Loading backends from %s", path_to_backend); + ggml_backend_load_all_from_path(path_to_backend); + env->ReleaseStringUTFChars(nativeLibDir, path_to_backend); + + // Initialize backends + llama_backend_init(); + LOGi("Backend initiated; Log handler set."); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_load(JNIEnv *env, jobject, jstring jmodel_path) { + llama_model_params model_params = llama_model_default_params(); + + const auto *model_path = env->GetStringUTFChars(jmodel_path, 0); + LOGd("%s: Loading model from: \n%s\n", __func__, model_path); + + auto *model = llama_model_load_from_file(model_path, model_params); + env->ReleaseStringUTFChars(jmodel_path, model_path); + if (!model) { + return 1; + } + g_model = model; + return 0; +} + +static llama_context *init_context(llama_model *model, const int n_ctx = DEFAULT_CONTEXT_SIZE) { + if (!model) { + LOGe("%s: model cannot be null", __func__); + return nullptr; + } + + // Multi-threading setup + const int n_threads = std::max(N_THREADS_MIN, std::min(N_THREADS_MAX, + (int) sysconf(_SC_NPROCESSORS_ONLN) - + N_THREADS_HEADROOM)); + LOGi("%s: Using %d threads", __func__, n_threads); + + // Context parameters setup + llama_context_params ctx_params = llama_context_default_params(); + const int trained_context_size = llama_model_n_ctx_train(model); + if (n_ctx > trained_context_size) { + LOGw("%s: Model was trained with only %d context size! 
Enforcing %d context size...", + __func__, trained_context_size, n_ctx); + } + ctx_params.n_ctx = n_ctx; + ctx_params.n_batch = BATCH_SIZE; + ctx_params.n_ubatch = BATCH_SIZE; + ctx_params.n_threads = n_threads; + ctx_params.n_threads_batch = n_threads; + auto *context = llama_init_from_model(g_model, ctx_params); + if (context == nullptr) { + LOGe("%s: llama_new_context_with_model() returned null)", __func__); + } + return context; +} + +static common_sampler *new_sampler(float temp) { + common_params_sampling sparams; + sparams.temp = temp; + return common_sampler_init(g_model, sparams); +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_prepare(JNIEnv * /*env*/, jobject /*unused*/) { + auto *context = init_context(g_model); + if (!context) { return 1; } + g_context = context; + g_batch = llama_batch_init(BATCH_SIZE, 0, 1); + g_chat_templates = common_chat_templates_init(g_model, ""); + g_sampler = new_sampler(DEFAULT_SAMPLER_TEMP); + return 0; +} + +static std::string get_backend() { + std::vector<std::string> backends; + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto *reg = ggml_backend_reg_get(i); + std::string name = ggml_backend_reg_name(reg); + if (name != "CPU") { + backends.push_back(ggml_backend_reg_name(reg)); + } + } + return backends.empty() ? "CPU" : join(backends, ","); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_systemInfo(JNIEnv *env, jobject /*unused*/) { + return env->NewStringUTF(llama_print_system_info()); +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_benchModel(JNIEnv *env, jobject /*unused*/, jint pp, jint tg, + jint pl, jint nr) { + auto *context = init_context(g_model, pp); + if (!context) { + const auto *const err_msg = "Fail to init_context! 
Bench aborted."; + LOGe(err_msg); + return env->NewStringUTF(err_msg); + } + + auto pp_avg = 0.0; + auto tg_avg = 0.0; + auto pp_std = 0.0; + auto tg_std = 0.0; + + const uint32_t n_ctx = llama_n_ctx(context); + LOGi("n_ctx = %d", n_ctx); + + int i, j; + int nri; + for (nri = 0; nri < nr; nri++) { + LOGi("Benchmark prompt processing (pp = %d)", pp); + + common_batch_clear(g_batch); + + const int n_tokens = pp; + for (i = 0; i < n_tokens; i++) { + common_batch_add(g_batch, 0, i, {0}, false); + } + + g_batch.logits[g_batch.n_tokens - 1] = true; + llama_memory_clear(llama_get_memory(context), false); + + const auto t_pp_start = ggml_time_us(); + if (llama_decode(context, g_batch) != 0) { + LOGe("llama_decode() failed during prompt processing"); + } + const auto t_pp_end = ggml_time_us(); + + // bench text generation + + LOGi("Benchmark text generation (tg = %d)", tg); + + llama_memory_clear(llama_get_memory(context), false); + const auto t_tg_start = ggml_time_us(); + for (i = 0; i < tg; i++) { + common_batch_clear(g_batch); + for (j = 0; j < pl; j++) { + common_batch_add(g_batch, 0, i, {j}, true); + } + + if (llama_decode(context, g_batch) != 0) { + LOGe("llama_decode() failed during text generation"); + } + } + const auto t_tg_end = ggml_time_us(); + + llama_memory_clear(llama_get_memory(context), false); + + const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; + const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; + + const auto speed_pp = double(pp) / t_pp; + const auto speed_tg = double(pl * tg) / t_tg; + + pp_avg += speed_pp; + tg_avg += speed_tg; + + pp_std += speed_pp * speed_pp; + tg_std += speed_tg * speed_tg; + + LOGi("pp %f t/s, tg %f t/s", speed_pp, speed_tg); + } + + llama_free(context); + + pp_avg /= double(nr); + tg_avg /= double(nr); + + if (nr > 1) { + pp_std = sqrt(pp_std / double(nr - 1) - pp_avg * pp_avg * double(nr) / double(nr - 1)); + tg_std = sqrt(tg_std / double(nr - 1) - tg_avg * tg_avg * double(nr) / double(nr - 1)); + } 
else { + pp_std = 0; + tg_std = 0; + } + + char model_desc[128]; + llama_model_desc(g_model, model_desc, sizeof(model_desc)); + + const auto model_size = double(llama_model_size(g_model)) / 1024.0 / 1024.0 / 1024.0; + const auto model_n_params = double(llama_model_n_params(g_model)) / 1e9; + + const auto backend = get_backend(); + std::stringstream result; + result << std::setprecision(3); + result << "| model | size | params | backend | test | t/s |\n"; + result << "| --- | --- | --- | --- | --- | --- |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | pp " << pp << " | " << pp_avg << " ± " << pp_std << " |\n"; + result << "| " << model_desc << " | " << model_size << "GiB | " << model_n_params << "B | " + << backend << " | tg " << tg << " | " << tg_avg << " ± " << tg_std << " |\n"; + return env->NewStringUTF(result.str().c_str()); +} + + +/** + * Completion loop's long-term states: + * - chat management + * - position tracking + */ +constexpr const char *ROLE_SYSTEM = "system"; +constexpr const char *ROLE_USER = "user"; +constexpr const char *ROLE_ASSISTANT = "assistant"; + +static std::vector<common_chat_msg> chat_msgs; +static llama_pos system_prompt_position; +static llama_pos current_position; + +static void reset_long_term_states(const bool clear_kv_cache = true) { + chat_msgs.clear(); + system_prompt_position = 0; + current_position = 0; + + if (clear_kv_cache) + llama_memory_clear(llama_get_memory(g_context), false); +} + +/** + * TODO-hyin: implement sliding-window version as a better alternative + * + * Context shifting by discarding the older half of the tokens appended after system prompt: + * - take the [system_prompt_position] first tokens from the original prompt + * - take half of the last (system_prompt_position - system_prompt_position) tokens + * - recompute the logits in batches + */ +static void shift_context() { + const int n_discard = (current_position - 
system_prompt_position) / 2; + LOGi("%s: Discarding %d tokens", __func__, n_discard); + llama_memory_seq_rm(llama_get_memory(g_context), 0, system_prompt_position, system_prompt_position + n_discard); + llama_memory_seq_add(llama_get_memory(g_context), 0, system_prompt_position + n_discard, current_position, -n_discard); + current_position -= n_discard; + LOGi("%s: Context shifting done! Current position: %d", __func__, current_position); +} + +static std::string chat_add_and_format(const std::string &role, const std::string &content) { + common_chat_msg new_msg; + new_msg.role = role; + new_msg.content = content; + auto formatted = common_chat_format_single( + g_chat_templates.get(), chat_msgs, new_msg, role == ROLE_USER, /* use_jinja */ false); + chat_msgs.push_back(new_msg); + LOGi("%s: Formatted and added %s message: \n%s\n", __func__, role.c_str(), formatted.c_str()); + return formatted; +} + +/** + * Completion loop's short-term states: + * - stop generation position + * - token chars caching + * - current assistant message being generated + */ +static llama_pos stop_generation_position; +static std::string cached_token_chars; +static std::ostringstream assistant_ss; + +static void reset_short_term_states() { + stop_generation_position = 0; + cached_token_chars.clear(); + assistant_ss.str(""); +} + +static int decode_tokens_in_batches( + llama_context *context, + llama_batch &batch, + const llama_tokens &tokens, + const llama_pos start_pos, + const bool compute_last_logit = false) { + // Process tokens in batches using the global batch + LOGd("%s: Decode %d tokens starting at position %d", __func__, (int) tokens.size(), start_pos); + for (int i = 0; i < (int) tokens.size(); i += BATCH_SIZE) { + const int cur_batch_size = std::min((int) tokens.size() - i, BATCH_SIZE); + common_batch_clear(batch); + LOGv("%s: Preparing a batch size of %d starting at: %d", __func__, cur_batch_size, i); + + // Shift context if current batch cannot fit into the context + if 
(start_pos + i + cur_batch_size >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { + LOGw("%s: Current batch won't fit into context! Shifting...", __func__); + shift_context(); + } + + // Add tokens to the batch with proper positions + for (int j = 0; j < cur_batch_size; j++) { + const llama_token token_id = tokens[i + j]; + const llama_pos position = start_pos + i + j; + const bool want_logit = compute_last_logit && (i + j == tokens.size() - 1); + common_batch_add(batch, token_id, position, {0}, want_logit); + } + + // Decode this batch + const int decode_result = llama_decode(context, batch); + if (decode_result) { + LOGe("%s: llama_decode failed w/ %d", __func__, decode_result); + return 1; + } + } + return 0; +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_processSystemPrompt( + JNIEnv *env, + jobject /*unused*/, + jstring jsystem_prompt +) { + // Reset long-term & short-term states + reset_long_term_states(); + reset_short_term_states(); + + // Obtain system prompt from JEnv + const auto *system_prompt = env->GetStringUTFChars(jsystem_prompt, nullptr); + LOGd("%s: System prompt received: \n%s", __func__, system_prompt); + std::string formatted_system_prompt(system_prompt); + env->ReleaseStringUTFChars(jsystem_prompt, system_prompt); + + // Format system prompt if applicable + const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); + if (has_chat_template) { + formatted_system_prompt = chat_add_and_format(ROLE_SYSTEM, system_prompt); + } + + // Tokenize system prompt + const auto system_tokens = common_tokenize(g_context, formatted_system_prompt, + has_chat_template, has_chat_template); + for (auto id: system_tokens) { + LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); + } + + // Handle context overflow + const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; + if ((int) system_tokens.size() > max_batch_size) { + LOGe("%s: System prompt too 
long for context! %d tokens, max: %d", + __func__, (int) system_tokens.size(), max_batch_size); + return 1; + } + + // Decode system tokens in batches + if (decode_tokens_in_batches(g_context, g_batch, system_tokens, current_position)) { + LOGe("%s: llama_decode() failed!", __func__); + return 2; + } + + // Update position + system_prompt_position = current_position = (int) system_tokens.size(); + return 0; +} + +extern "C" +JNIEXPORT jint JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_processUserPrompt( + JNIEnv *env, + jobject /*unused*/, + jstring juser_prompt, + jint n_predict +) { + // Reset short-term states + reset_short_term_states(); + + // Obtain and tokenize user prompt + const auto *const user_prompt = env->GetStringUTFChars(juser_prompt, nullptr); + LOGd("%s: User prompt received: \n%s", __func__, user_prompt); + std::string formatted_user_prompt(user_prompt); + env->ReleaseStringUTFChars(juser_prompt, user_prompt); + + // Format user prompt if applicable + const bool has_chat_template = common_chat_templates_was_explicit(g_chat_templates.get()); + if (has_chat_template) { + formatted_user_prompt = chat_add_and_format(ROLE_USER, user_prompt); + } + + // Decode formatted user prompts + auto user_tokens = common_tokenize(g_context, formatted_user_prompt, has_chat_template, has_chat_template); + for (auto id: user_tokens) { + LOGv("token: `%s`\t -> `%d`", common_token_to_piece(g_context, id).c_str(), id); + } + + // Ensure user prompt doesn't exceed the context size by truncating if necessary. + const int user_prompt_size = (int) user_tokens.size(); + const int max_batch_size = DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM; + if (user_prompt_size > max_batch_size) { + const int skipped_tokens = user_prompt_size - max_batch_size; + user_tokens.resize(max_batch_size); + LOGw("%s: User prompt too long! 
Skipped %d tokens!", __func__, skipped_tokens); + } + + // Decode user tokens in batches + if (decode_tokens_in_batches(g_context, g_batch, user_tokens, current_position, true)) { + LOGe("%s: llama_decode() failed!", __func__); + return 2; + } + + // Update position + current_position += user_prompt_size; + stop_generation_position = current_position + user_prompt_size + n_predict; + return 0; +} + +static bool is_valid_utf8(const char *string) { + if (!string) { return true; } + + const auto *bytes = (const unsigned char *) string; + int num; + + while (*bytes != 0x00) { + if ((*bytes & 0x80) == 0x00) { + // U+0000 to U+007F + num = 1; + } else if ((*bytes & 0xE0) == 0xC0) { + // U+0080 to U+07FF + num = 2; + } else if ((*bytes & 0xF0) == 0xE0) { + // U+0800 to U+FFFF + num = 3; + } else if ((*bytes & 0xF8) == 0xF0) { + // U+10000 to U+10FFFF + num = 4; + } else { + return false; + } + + bytes += 1; + for (int i = 1; i < num; ++i) { + if ((*bytes & 0xC0) != 0x80) { + return false; + } + bytes += 1; + } + } + return true; +} + +extern "C" +JNIEXPORT jstring JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_generateNextToken( + JNIEnv *env, + jobject /*unused*/ +) { + // Infinite text generation via context shifting + if (current_position >= DEFAULT_CONTEXT_SIZE - OVERFLOW_HEADROOM) { + LOGw("%s: Context full! 
Shifting...", __func__); + shift_context(); + } + + // Stop if reaching the marked position + if (current_position >= stop_generation_position) { + LOGw("%s: STOP: hitting stop position: %d", __func__, stop_generation_position); + return nullptr; + } + + // Sample next token + const auto new_token_id = common_sampler_sample(g_sampler, g_context, -1); + common_sampler_accept(g_sampler, new_token_id, true); + + // Populate the batch with new token, then decode + common_batch_clear(g_batch); + common_batch_add(g_batch, new_token_id, current_position, {0}, true); + if (llama_decode(g_context, g_batch) != 0) { + LOGe("%s: llama_decode() failed for generated token", __func__); + return nullptr; + } + + // Update position + current_position++; + + // Stop if next token is EOG + if (llama_vocab_is_eog(llama_model_get_vocab(g_model), new_token_id)) { + LOGd("id: %d,\tIS EOG!\nSTOP.", new_token_id); + chat_add_and_format(ROLE_ASSISTANT, assistant_ss.str()); + return nullptr; + } + + // If not EOG, convert to text + auto new_token_chars = common_token_to_piece(g_context, new_token_id); + cached_token_chars += new_token_chars; + + // Create and return a valid UTF-8 Java string + jstring result = nullptr; + if (is_valid_utf8(cached_token_chars.c_str())) { + result = env->NewStringUTF(cached_token_chars.c_str()); + LOGv("id: %d,\tcached: `%s`,\tnew: `%s`", new_token_id, cached_token_chars.c_str(), new_token_chars.c_str()); + + assistant_ss << cached_token_chars; + cached_token_chars.clear(); + } else { + LOGv("id: %d,\tappend to cache", new_token_id); + result = env->NewStringUTF(""); + } + return result; +} + + +extern "C" +JNIEXPORT void JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_unload(JNIEnv * /*unused*/, jobject /*unused*/) { + // Reset long-term & short-term states + reset_long_term_states(); + reset_short_term_states(); + + // Free up resources + common_sampler_free(g_sampler); + g_chat_templates.reset(); + llama_batch_free(g_batch); + 
llama_free(g_context); + llama_model_free(g_model); +} + +extern "C" +JNIEXPORT void JNICALL +Java_com_arm_aichat_internal_InferenceEngineImpl_shutdown(JNIEnv *, jobject /*unused*/) { + llama_backend_free(); +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h b/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h new file mode 100644 index 0000000..2e768d2 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/cpp/logging.h @@ -0,0 +1,61 @@ +// +// Created by Han Yin on 10/31/25. +// + +#ifndef AICHAT_LOGGING_H +#define AICHAT_LOGGING_H + +#endif //AICHAT_LOGGING_H + +#pragma once +#include <android/log.h> + +#ifndef LOG_TAG +#define LOG_TAG "ai-chat" +#endif + +#ifndef LOG_MIN_LEVEL +#if defined(NDEBUG) +#define LOG_MIN_LEVEL ANDROID_LOG_INFO +#else +#define LOG_MIN_LEVEL ANDROID_LOG_VERBOSE +#endif +#endif + +static inline int ai_should_log(int prio) { + return __android_log_is_loggable(prio, LOG_TAG, LOG_MIN_LEVEL); +} + +#if LOG_MIN_LEVEL <= ANDROID_LOG_VERBOSE +#define LOGv(...) do { if (ai_should_log(ANDROID_LOG_VERBOSE)) __android_log_print(ANDROID_LOG_VERBOSE, LOG_TAG, __VA_ARGS__); } while (0) +#else +#define LOGv(...) ((void)0) +#endif + +#if LOG_MIN_LEVEL <= ANDROID_LOG_DEBUG +#define LOGd(...) do { if (ai_should_log(ANDROID_LOG_DEBUG)) __android_log_print(ANDROID_LOG_DEBUG, LOG_TAG, __VA_ARGS__); } while (0) +#else +#define LOGd(...) ((void)0) +#endif + +#define LOGi(...) do { if (ai_should_log(ANDROID_LOG_INFO )) __android_log_print(ANDROID_LOG_INFO , LOG_TAG, __VA_ARGS__); } while (0) +#define LOGw(...) do { if (ai_should_log(ANDROID_LOG_WARN )) __android_log_print(ANDROID_LOG_WARN , LOG_TAG, __VA_ARGS__); } while (0) +#define LOGe(...) 
do { if (ai_should_log(ANDROID_LOG_ERROR)) __android_log_print(ANDROID_LOG_ERROR, LOG_TAG, __VA_ARGS__); } while (0) + +static inline int android_log_prio_from_ggml(enum ggml_log_level level) { + switch (level) { + case GGML_LOG_LEVEL_ERROR: return ANDROID_LOG_ERROR; + case GGML_LOG_LEVEL_WARN: return ANDROID_LOG_WARN; + case GGML_LOG_LEVEL_INFO: return ANDROID_LOG_INFO; + case GGML_LOG_LEVEL_DEBUG: return ANDROID_LOG_DEBUG; + default: return ANDROID_LOG_DEFAULT; + } +} + +static inline void aichat_android_log_callback(enum ggml_log_level level, + const char* text, + void* /*user*/) { + const int prio = android_log_prio_from_ggml(level); + if (!ai_should_log(prio)) return; + __android_log_write(prio, LOG_TAG, text); +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt new file mode 100644 index 0000000..b72a24e --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/AiChat.kt @@ -0,0 +1,14 @@ +package com.arm.aichat + +import android.content.Context +import com.arm.aichat.internal.InferenceEngineImpl + +/** + * Main entry point for Arm's AI Chat library. + */ +object AiChat { + /** + * Get the inference engine single instance. + */ + fun getInferenceEngine(context: Context) = InferenceEngineImpl.getInstance(context) +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt new file mode 100644 index 0000000..26c1668 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/InferenceEngine.kt @@ -0,0 +1,89 @@ +package com.arm.aichat + +import com.arm.aichat.InferenceEngine.State +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.StateFlow + +/** + * Interface defining the core LLM inference operations. 
+ */ +interface InferenceEngine { + /** + * Current state of the inference engine + */ + val state: StateFlow<State> + + /** + * Load a model from the given path. + * + * @throws UnsupportedArchitectureException if model architecture not supported + */ + suspend fun loadModel(pathToModel: String) + + /** + * Sends a system prompt to the loaded model + */ + suspend fun setSystemPrompt(systemPrompt: String) + + /** + * Sends a user prompt to the loaded model and returns a Flow of generated tokens. + */ + fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String> + + /** + * Runs a benchmark with the specified parameters. + */ + suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String + + /** + * Unloads the currently loaded model. + */ + fun cleanUp() + + /** + * Cleans up resources when the engine is no longer needed. + */ + fun destroy() + + /** + * States of the inference engine + */ + sealed class State { + object Uninitialized : State() + object Initializing : State() + object Initialized : State() + + object LoadingModel : State() + object UnloadingModel : State() + object ModelReady : State() + + object Benchmarking : State() + object ProcessingSystemPrompt : State() + object ProcessingUserPrompt : State() + + object Generating : State() + + data class Error(val exception: Exception) : State() + } + + companion object { + const val DEFAULT_PREDICT_LENGTH = 1024 + } +} + +val State.isUninterruptible + get() = this is State.Initializing || + this is State.LoadingModel || + this is State.UnloadingModel || + this is State.Benchmarking || + this is State.ProcessingSystemPrompt || + this is State.ProcessingUserPrompt + +val State.isModelLoaded: Boolean + get() = this is State.ModelReady || + this is State.Benchmarking || + this is State.ProcessingSystemPrompt || + this is State.ProcessingUserPrompt || + this is State.Generating + +class UnsupportedArchitectureException : Exception() diff --git 
a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt new file mode 100644 index 0000000..2f15eef --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/FileType.kt @@ -0,0 +1,61 @@ +package com.arm.aichat.gguf + +import kotlin.collections.get + + +/** + * Numerical codes used by `general.file_type` (see llama.cpp repo's `constants.py`). + * The `label` matches what llama‑cli prints. + */ +enum class FileType(val code: Int, val label: String) { + ALL_F32(0, "all F32"), + MOSTLY_F16(1, "F16"), + MOSTLY_Q4_0(2, "Q4_0"), + MOSTLY_Q4_1(3, "Q4_1"), + // 4 removed + MOSTLY_Q8_0(7, "Q8_0"), + MOSTLY_Q5_0(8, "Q5_0"), + MOSTLY_Q5_1(9, "Q5_1"), + + /* K‑quants ------------------------------------------------------------ */ + MOSTLY_Q2_K (10, "Q2_K - Medium"), + MOSTLY_Q3_K_S (11, "Q3_K - Small"), + MOSTLY_Q3_K_M (12, "Q3_K - Medium"), + MOSTLY_Q3_K_L (13, "Q3_K - Large"), + MOSTLY_Q4_K_S (14, "Q4_K - Small"), + MOSTLY_Q4_K_M (15, "Q4_K - Medium"), + MOSTLY_Q5_K_S (16, "Q5_K - Small"), + MOSTLY_Q5_K_M (17, "Q5_K - Medium"), + MOSTLY_Q6_K (18, "Q6_K"), + + /* IQ quants ----------------------------------------------------------- */ + MOSTLY_IQ2_XXS (19, "IQ2_XXS - 2.06 bpw"), + MOSTLY_IQ2_XS (20, "IQ2_XS - 2.31 bpw"), + MOSTLY_Q2_K_S (21, "Q2_K - Small"), + MOSTLY_IQ3_XS (22, "IQ3_XS - 3.30 bpw"), + MOSTLY_IQ3_XXS (23, "IQ3_XXS - 3.06 bpw"), + MOSTLY_IQ1_S (24, "IQ1_S - 1.56 bpw"), + MOSTLY_IQ4_NL (25, "IQ4_NL - 4.5 bpw"), + MOSTLY_IQ3_S (26, "IQ3_S - 3.44 bpw"), + MOSTLY_IQ3_M (27, "IQ3_M - 3.66 bpw"), + MOSTLY_IQ2_S (28, "IQ2_S - 2.50 bpw"), + MOSTLY_IQ2_M (29, "IQ2_M - 2.70 bpw"), + MOSTLY_IQ4_XS (30, "IQ4_XS - 4.25 bpw"), + MOSTLY_IQ1_M (31, "IQ1_M - 1.75 bpw"), + + /* BF16 & Ternary ------------------------------------------------------ */ + MOSTLY_BF16 (32, "BF16"), + MOSTLY_TQ1_0 (36, "TQ1_0 - 1.69 bpw ternary"), + 
MOSTLY_TQ2_0 (37, "TQ2_0 - 2.06 bpw ternary"), + + /* Special flag -------------------------------------------------------- */ + GUESSED(1024, "(guessed)"), + + UNKNOWN(-1, "unknown"); + + companion object { + private val map = entries.associateBy(FileType::code) + + fun fromCode(code: Int?): FileType = map[code] ?: UNKNOWN + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt new file mode 100644 index 0000000..5e1971a --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadata.kt @@ -0,0 +1,132 @@ +package com.arm.aichat.gguf + +import java.io.IOException + + +/** + * Structured metadata of GGUF + */ +data class GgufMetadata( + // Basic file info + val version: GgufVersion, + val tensorCount: Long, + val kvCount: Long, + + // General info + val basic: BasicInfo, + val author: AuthorInfo? = null, + val additional: AdditionalInfo? = null, + val architecture: ArchitectureInfo? = null, + val baseModels: List<BaseModelInfo>? = null, + val tokenizer: TokenizerInfo? = null, + + // Derivative info + val dimensions: DimensionsInfo? = null, + val attention: AttentionInfo? = null, + val rope: RopeInfo? = null, + val experts: ExpertsInfo? = null +) { + enum class GgufVersion(val code: Int, val label: String) { + /** First public draft; little‑endian only, no alignment key. */ + LEGACY_V1(1, "Legacy v1"), + + /** Added split‑file support and some extra metadata keys. */ + EXTENDED_V2(2, "Extended v2"), + + /** Current spec: endian‑aware, mandatory alignment, fully validated. */ + VALIDATED_V3(3, "Validated v3"); + + companion object { + fun fromCode(code: Int): GgufVersion = + entries.firstOrNull { it.code == code } + ?: throw IOException("Unknown GGUF version code $code") + } + + override fun toString(): String = "$label (code=$code)" + } + + data class BasicInfo( + val uuid: String? 
= null, + val name: String? = null, + val nameLabel: String? = null, + val sizeLabel: String? = null, // Size label like "7B" + ) + + data class AuthorInfo( + val organization: String? = null, + val author: String? = null, + val doi: String? = null, + val url: String? = null, + val repoUrl: String? = null, + val license: String? = null, + val licenseLink: String? = null, + ) + + data class AdditionalInfo( + val type: String? = null, + val description: String? = null, + val tags: List<String>? = null, + val languages: List<String>? = null, + ) + + data class ArchitectureInfo( + val architecture: String? = null, + val fileType: Int? = null, + val vocabSize: Int? = null, + val finetune: String? = null, + val quantizationVersion: Int? = null, + ) + + data class BaseModelInfo( + val name: String? = null, + val author: String? = null, + val version: String? = null, + val organization: String? = null, + val url: String? = null, + val doi: String? = null, + val uuid: String? = null, + val repoUrl: String? = null, + ) + + data class TokenizerInfo( + val model: String? = null, + val bosTokenId: Int? = null, + val eosTokenId: Int? = null, + val unknownTokenId: Int? = null, + val paddingTokenId: Int? = null, + val addBosToken: Boolean? = null, + val addEosToken: Boolean? = null, + val chatTemplate: String? = null, + ) + + data class DimensionsInfo( + val contextLength: Int? = null, + val embeddingSize: Int? = null, + val blockCount: Int? = null, + val feedForwardSize: Int? = null, + ) + + data class AttentionInfo( + val headCount: Int? = null, + val headCountKv: Int? = null, + val keyLength: Int? = null, + val valueLength: Int? = null, + val layerNormEpsilon: Float? = null, + val layerNormRmsEpsilon: Float? = null, + ) + + data class RopeInfo( + val frequencyBase: Float? = null, + val dimensionCount: Int? = null, + val scalingType: String? = null, + val scalingFactor: Float? = null, + val attnFactor: Float? = null, + val originalContextLength: Int? 
= null, + val finetuned: Boolean? = null, + ) + + data class ExpertsInfo( + val count: Int? = null, + val usedCount: Int? = null, + ) +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt new file mode 100644 index 0000000..264a6c0 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/gguf/GgufMetadataReader.kt @@ -0,0 +1,77 @@ +package com.arm.aichat.gguf + +import android.content.Context +import android.net.Uri +import com.arm.aichat.internal.gguf.GgufMetadataReaderImpl +import java.io.File +import java.io.IOException +import java.io.InputStream + +/** + * Interface for reading GGUF metadata from model files. + * Use `GgufMetadataReader.create()` to get an instance. + */ +interface GgufMetadataReader { + /** + * Reads the magic number from the specified file path. + * + * @param file Java File to the GGUF file with absolute path + * @return true if file is valid GGUF, otherwise false + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun ensureSourceFileFormat(file: File): Boolean + + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining [android.content.ContentProvider] + * @param uri Uri to the GGUF file provided by [android.content.ContentProvider] + * @return true if file is valid GGUF, otherwise false + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean + + /** + * Reads and parses GGUF metadata from the specified file path. 
+ * + * @param input the [InputStream] obtained from a readable file or content + * @return Structured metadata extracted from the file + * @throws IOException if file is damaged or cannot be read + * @throws InvalidFileFormatException if file format is invalid + */ + suspend fun readStructuredMetadata(input: InputStream): GgufMetadata + + companion object { + private val DEFAULT_SKIP_KEYS = setOf( + "tokenizer.chat_template", + "tokenizer.ggml.scores", + "tokenizer.ggml.tokens", + "tokenizer.ggml.token_type" + ) + + /** + * Creates a default GgufMetadataReader instance + */ + fun create(): GgufMetadataReader = GgufMetadataReaderImpl( + skipKeys = DEFAULT_SKIP_KEYS, + arraySummariseThreshold = 1_000 + ) + + /** + * Creates a GgufMetadataReader with custom configuration + * + * @param skipKeys Keys whose value should be skipped entirely (not kept in the result map) + * @param arraySummariseThreshold If ≥0, arrays longer get summarised, not materialised; + * If -1, never summarise. + */ + fun create( + skipKeys: Set<String> = DEFAULT_SKIP_KEYS, + arraySummariseThreshold: Int = 1_000 + ): GgufMetadataReader = GgufMetadataReaderImpl( + skipKeys = skipKeys, + arraySummariseThreshold = arraySummariseThreshold + ) + } +} + +class InvalidFileFormatException : IOException() diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt new file mode 100644 index 0000000..a293796 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/InferenceEngineImpl.kt @@ -0,0 +1,324 @@ +package com.arm.aichat.internal + +import android.content.Context +import android.util.Log +import com.arm.aichat.InferenceEngine +import com.arm.aichat.UnsupportedArchitectureException +import com.arm.aichat.internal.InferenceEngineImpl.Companion.getInstance +import dalvik.annotation.optimization.FastNative +import 
kotlinx.coroutines.CancellationException +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.flow.Flow +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withContext +import java.io.File +import java.io.IOException + +/** + * JNI wrapper for the llama.cpp library providing Android-friendly access to large language models. + * + * This class implements a singleton pattern for managing the lifecycle of a single LLM instance. + * All operations are executed on a dedicated single-threaded dispatcher to ensure thread safety + * with the underlying C++ native code. + * + * The typical usage flow is: + * 1. Get instance via [getInstance] + * 2. Load a model with [loadModel] + * 3. Send prompts with [sendUserPrompt] + * 4. Generate responses as token streams + * 5. Perform [cleanUp] when done with a model + * 6. Properly [destroy] when completely done + * + * State transitions are managed automatically and validated at each operation. + * + * @see ai_chat.cpp for the native implementation details + */ +internal class InferenceEngineImpl private constructor( + private val nativeLibDir: String +) : InferenceEngine { + + companion object { + private val TAG = InferenceEngineImpl::class.java.simpleName + + @Volatile + private var instance: InferenceEngine? = null + + /** + * Create or obtain [InferenceEngineImpl]'s single instance. 
+ * @param context Context for obtaining native library directory
+ + @FastNative + private external fun unload() + + @FastNative + private external fun shutdown() + + private val _state = + MutableStateFlow<InferenceEngine.State>(InferenceEngine.State.Uninitialized) + override val state: StateFlow<InferenceEngine.State> = _state.asStateFlow() + + private var _readyForSystemPrompt = false + @Volatile + private var _cancelGeneration = false + + /** + * Single-threaded coroutine dispatcher & scope for LLama asynchronous operations + */ + @OptIn(ExperimentalCoroutinesApi::class) + private val llamaDispatcher = Dispatchers.IO.limitedParallelism(1) + private val llamaScope = CoroutineScope(llamaDispatcher + SupervisorJob()) + + init { + llamaScope.launch { + try { + check(_state.value is InferenceEngine.State.Uninitialized) { + "Cannot load native library in ${_state.value.javaClass.simpleName}!" + } + _state.value = InferenceEngine.State.Initializing + Log.i(TAG, "Loading native library...") + System.loadLibrary("ai-chat") + init(nativeLibDir) + _state.value = InferenceEngine.State.Initialized + Log.i(TAG, "Native library loaded! System info: \n${systemInfo()}") + + } catch (e: Exception) { + Log.e(TAG, "Failed to load native library", e) + throw e + } + } + } + + /** + * Load the LLM + */ + override suspend fun loadModel(pathToModel: String) = + withContext(llamaDispatcher) { + check(_state.value is InferenceEngine.State.Initialized) { + "Cannot load model in ${_state.value.javaClass.simpleName}!" + } + + try { + Log.i(TAG, "Checking access to model file... \n$pathToModel") + File(pathToModel).let { + require(it.exists()) { "File not found" } + require(it.isFile) { "Not a valid file" } + require(it.canRead()) { "Cannot read file" } + } + + Log.i(TAG, "Loading model... 
\n$pathToModel") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.LoadingModel + load(pathToModel).let { + // TODO-han.yin: find a better way to pass other error codes + if (it != 0) throw UnsupportedArchitectureException() + } + prepare().let { + if (it != 0) throw IOException("Failed to prepare resources") + } + Log.i(TAG, "Model loaded!") + _readyForSystemPrompt = true + + _cancelGeneration = false + _state.value = InferenceEngine.State.ModelReady + } catch (e: Exception) { + Log.e(TAG, (e.message ?: "Error loading model") + "\n" + pathToModel, e) + _state.value = InferenceEngine.State.Error(e) + throw e + } + } + + /** + * Process the plain text system prompt + * + * TODO-han.yin: return error code if system prompt not correct processed? + */ + override suspend fun setSystemPrompt(prompt: String) = + withContext(llamaDispatcher) { + require(prompt.isNotBlank()) { "Cannot process empty system prompt!" } + check(_readyForSystemPrompt) { "System prompt must be set ** RIGHT AFTER ** model loaded!" } + check(_state.value is InferenceEngine.State.ModelReady) { + "Cannot process system prompt in ${_state.value.javaClass.simpleName}!" + } + + Log.i(TAG, "Sending system prompt...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.ProcessingSystemPrompt + processSystemPrompt(prompt).let { result -> + if (result != 0) { + RuntimeException("Failed to process system prompt: $result").also { + _state.value = InferenceEngine.State.Error(it) + throw it + } + } + } + Log.i(TAG, "System prompt processed! Awaiting user prompt...") + _state.value = InferenceEngine.State.ModelReady + } + + /** + * Send plain text user prompt to LLM, which starts generating tokens in a [Flow] + */ + override fun sendUserPrompt( + message: String, + predictLength: Int, + ): Flow<String> = flow { + require(message.isNotEmpty()) { "User prompt discarded due to being empty!" 
} + check(_state.value is InferenceEngine.State.ModelReady) { + "User prompt discarded due to: ${_state.value.javaClass.simpleName}" + } + + try { + Log.i(TAG, "Sending user prompt...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.ProcessingUserPrompt + + processUserPrompt(message, predictLength).let { result -> + if (result != 0) { + Log.e(TAG, "Failed to process user prompt: $result") + return@flow + } + } + + Log.i(TAG, "User prompt processed. Generating assistant prompt...") + _state.value = InferenceEngine.State.Generating + while (!_cancelGeneration) { + generateNextToken()?.let { utf8token -> + if (utf8token.isNotEmpty()) emit(utf8token) + } ?: break + } + if (_cancelGeneration) { + Log.i(TAG, "Assistant generation aborted per requested.") + } else { + Log.i(TAG, "Assistant generation complete. Awaiting user prompt...") + } + _state.value = InferenceEngine.State.ModelReady + } catch (e: CancellationException) { + Log.i(TAG, "Assistant generation's flow collection cancelled.") + _state.value = InferenceEngine.State.ModelReady + throw e + } catch (e: Exception) { + Log.e(TAG, "Error during generation!", e) + _state.value = InferenceEngine.State.Error(e) + throw e + } + }.flowOn(llamaDispatcher) + + /** + * Benchmark the model + */ + override suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int): String = + withContext(llamaDispatcher) { + check(_state.value is InferenceEngine.State.ModelReady) { + "Benchmark request discarded due to: $state" + } + Log.i(TAG, "Start benchmark (pp: $pp, tg: $tg, pl: $pl, nr: $nr)") + _readyForSystemPrompt = false // Just to be safe + _state.value = InferenceEngine.State.Benchmarking + benchModel(pp, tg, pl, nr).also { + _state.value = InferenceEngine.State.ModelReady + } + } + + /** + * Unloads the model and frees resources, or reset error states + */ + override fun cleanUp() { + _cancelGeneration = true + runBlocking(llamaDispatcher) { + when (val state = _state.value) { + is 
InferenceEngine.State.ModelReady -> { + Log.i(TAG, "Unloading model and free resources...") + _readyForSystemPrompt = false + _state.value = InferenceEngine.State.UnloadingModel + + unload() + + _state.value = InferenceEngine.State.Initialized + Log.i(TAG, "Model unloaded!") + Unit + } + + is InferenceEngine.State.Error -> { + Log.i(TAG, "Resetting error states...") + _state.value = InferenceEngine.State.Initialized + Log.i(TAG, "States reset!") + Unit + } + + else -> throw IllegalStateException("Cannot unload model in ${state.javaClass.simpleName}") + } + } + } + + /** + * Cancel all ongoing coroutines and free GGML backends + */ + override fun destroy() { + _cancelGeneration = true + runBlocking(llamaDispatcher) { + _readyForSystemPrompt = false + when(_state.value) { + is InferenceEngine.State.Uninitialized -> {} + is InferenceEngine.State.Initialized -> shutdown() + else -> { unload(); shutdown() } + } + } + llamaScope.cancel() + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt new file mode 100644 index 0000000..bf250ac --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/main/java/com/arm/aichat/internal/gguf/GgufMetadataReaderImpl.kt @@ -0,0 +1,590 @@ +package com.arm.aichat.internal.gguf + +import android.content.Context +import android.net.Uri +import com.arm.aichat.gguf.GgufMetadata +import com.arm.aichat.gguf.GgufMetadataReader +import com.arm.aichat.gguf.InvalidFileFormatException +import java.io.File +import java.io.IOException +import java.io.InputStream + + +/** + * Utility class to read GGUF model files and extract metadata key-value pairs. + * This parser reads the header and metadata of a GGUF v3 file (little-endian) and skips tensor data. 
+ */ +internal class GgufMetadataReaderImpl( + private val skipKeys: Set<String>, + private val arraySummariseThreshold: Int, +) : GgufMetadataReader { + companion object { + private const val ARCH_LLAMA = "llama" + } + + /** Enum corresponding to GGUF metadata value types (for convenience and array element typing). */ + enum class MetadataType(val code: Int) { + UINT8(0), INT8(1), UINT16(2), INT16(3), + UINT32(4), INT32(5), FLOAT32(6), BOOL(7), + STRING(8), ARRAY(9), UINT64(10), INT64(11), FLOAT64(12); + companion object { + private val codeMap = entries.associateBy(MetadataType::code) + fun fromCode(code: Int): MetadataType = codeMap[code] + ?: throw IOException("Unknown metadata value type code: $code") + } + } + + /** Sealed class hierarchy for metadata values, providing type-safe representations for each GGUF metadata type. */ + sealed class MetadataValue { + data class UInt8(val value: UByte) : MetadataValue() // 0: 8-bit unsigned int + data class Int8(val value: Byte) : MetadataValue() // 1: 8-bit signed int + data class UInt16(val value: UShort) : MetadataValue() // 2: 16-bit unsigned int (little-endian) + data class Int16(val value: Short) : MetadataValue() // 3: 16-bit signed int (little-endian) + data class UInt32(val value: UInt) : MetadataValue() // 4: 32-bit unsigned int (little-endian) + data class Int32(val value: Int) : MetadataValue() // 5: 32-bit signed int (little-endian) + data class Float32(val value: Float) : MetadataValue() // 6: 32-bit IEEE754 float + data class Bool(val value: Boolean) : MetadataValue() // 7: Boolean (1-byte, 0=false, 1=true) + data class StringVal(val value: String) : MetadataValue() // 8: UTF-8 string (length-prefixed) + data class ArrayVal(val elementType: MetadataType, val elements: List<MetadataValue>) : MetadataValue() + data class UInt64(val value: ULong) : MetadataValue() // 10: 64-bit unsigned int (little-endian) + data class Int64(val value: Long) : MetadataValue() // 11: 64-bit signed int (little-endian) + data 
class Float64(val value: Double) : MetadataValue() // 12: 64-bit IEEE754 double + } + + /* Convert MetadataValue to plain Kotlin primitives for allMetadata map */ + private fun MetadataValue.toPrimitive(): Any = when (this) { + is MetadataValue.UInt8 -> value + is MetadataValue.Int8 -> value + is MetadataValue.UInt16 -> value + is MetadataValue.Int16 -> value + is MetadataValue.UInt32 -> value + is MetadataValue.Int32 -> value + is MetadataValue.Float32 -> value + is MetadataValue.Bool -> value + is MetadataValue.StringVal -> value + is MetadataValue.UInt64 -> value + is MetadataValue.Int64 -> value + is MetadataValue.Float64 -> value + is MetadataValue.ArrayVal -> elements.map { it.toPrimitive() } + } + + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + */ + override suspend fun ensureSourceFileFormat(file: File): Boolean = + file.inputStream().buffered().use { ensureMagic(it) } + + /** + * Reads the magic number from the specified file path. + * + * @param context Context for obtaining ContentResolver + * @param uri Uri to the GGUF file provided by ContentProvider + * @return true if file is valid GGUF, otherwise false + */ + override suspend fun ensureSourceFileFormat(context: Context, uri: Uri): Boolean = + context.contentResolver.openInputStream(uri)?.buffered()?.use { ensureMagic(it) } == true + + /** Reads the 4‑byte magic; throws if magic ≠ "GGUF". */ + private fun ensureMagic(input: InputStream): Boolean = + ByteArray(4).let { + if (input.read(it) != 4) throw IOException("Not a valid file!") + it.contentEquals(byteArrayOf(0x47, 0x47, 0x55, 0x46)) // "GGUF" + } + + /** + * High‑level entry point: parses a `.gguf` file on disk and returns the fully + * populated [GgufMetadata] tree. + * + * Steps performed internally: + * 1. 
Reads and validates the 8‑byte header (`"GGUF"` magic + version). + * 2. Streams through the key‑value section, skipping large blobs if the key + * appears in [skipKeys] or if an array exceeds [arraySummariseThreshold]. + * 3. Converts the resulting raw map into strongly‑typed sub‑structures + * (basic info, tokenizer, rope, etc.). + * + * The method is STREAMING‑ONLY: tensors are never mapped or loaded into + * memory, so even multi‑GB model files can be processed in < 50 ms. + * + * @param path Absolute or relative filesystem path to a `.gguf` file. + * @return A [GgufMetadata] instance containing all recognised metadata plus + * an `allMetadata` map with any keys that were not given a dedicated + * field. + * @throws IOException if the file is not GGUF, the version is unsupported, + * or the metadata block is truncated / corrupt. + */ + override suspend fun readStructuredMetadata(input: InputStream): GgufMetadata { + // ── 1. header ────────────────────────────────────────────────────────── + // throws on mismatch + val version = ensureMagicAndVersion(input) + val tensorCount = readLittleLong(input) + val kvCount = readLittleLong(input) + + // ── 2. metadata map (reuse our raw parser, but we need access to the stream) ── + val meta = readMetaMap(input, kvCount) // <String, MetadataValue> + + // ── 3. build structured object ──────────────────────────────────────── + return buildStructured(meta, version, tensorCount, kvCount) + } + + /** Reads the 4‑byte magic + 4‑byte version; throws if magic ≠ "GGUF". */ + private fun ensureMagicAndVersion(input: InputStream): GgufMetadata.GgufVersion { + if (!ensureMagic(input)) throw InvalidFileFormatException() + return GgufMetadata.GgufVersion.fromCode(readLEUInt32(input)) + } + + /** + * Read an unsigned 32‑bit little‑endian integer. + * + * @throws IOException if fewer than four bytes are available. 
+ */ + private fun readLEUInt32(input: InputStream): Int { + val b0 = input.read(); val b1 = input.read(); val b2 = input.read(); val b3 = input.read() + if (b3 == -1) throw IOException("Unexpected EOF while reading UInt32") + return (b3 and 0xFF shl 24) or + (b2 and 0xFF shl 16) or + (b1 and 0xFF shl 8) or + (b0 and 0xFF) + } + + /** + * Low‑level helper that reads the entire “key-value” section from the current + * stream position. + * + * @param input Open stream positioned JUST AFTER the header. + * @param kvCnt Number of key‑value pairs (taken from the header). + * @return Mutable map with one [MetadataValue] for every key that is NOT skipped. + * + * The function honours [skipKeys] and [arraySummariseThreshold] by invoking + * [skipValue] or [parseValue] accordingly. + */ + private fun readMetaMap(input: InputStream, kvCnt: Long): Map<String, MetadataValue> = + mutableMapOf<String, MetadataValue>().apply { + repeat(kvCnt.toInt()) { + val key = readString(input) + val valueT = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + if (key in skipKeys) { + skipValue(input, valueT) + } else { + this[key] = parseValue(input, valueT) + } + } + } + + /** + * Converts a flat [Map]<[String], [MetadataValue]> into the strongly‑typed + * [GgufMetadata] tree used by the rest of the app. + * + * Only the keys listed in the spec are copied into dedicated data classes; + * everything else is preserved in `GgufMetadata.allMetadata`. + * + * @param m Raw key/value map. + * @param version GGUF file‑format version (enum). + * @param tensorCnt Number of tensors (from the header). + * @param kvCnt Total metadata pair count (from the header). + */ + private fun buildStructured( + m: Map<String, MetadataValue>, + version: GgufMetadata.GgufVersion, + tensorCnt: Long, + kvCnt: Long + ): GgufMetadata { + // ---------- helpers ---------- + fun String.str() = (m[this] as? MetadataValue.StringVal)?.value + fun String.bool() = (m[this] as? 
MetadataValue.Bool)?.value + fun String.i32() = (m[this] as? MetadataValue.Int32)?.value + fun String.u32() = (m[this] as? MetadataValue.UInt32)?.value?.toInt() + fun String.f32() = (m[this] as? MetadataValue.Float32)?.value + fun String.f64() = (m[this] as? MetadataValue.Float64)?.value?.toFloat() + fun String.strList(): List<String>? = + (m[this] as? MetadataValue.ArrayVal) + ?.elements + ?.mapNotNull { (it as? MetadataValue.StringVal)?.value } + + val arch = "general.architecture".str() ?: ARCH_LLAMA + + // -------------- populate sections ---------------- + val basic = GgufMetadata.BasicInfo( + uuid = "general.uuid".str(), + name = "general.basename".str(), + nameLabel = "general.name".str(), + sizeLabel = "general.size_label".str() + ) + + val author = GgufMetadata.AuthorInfo( + organization = "general.organization".str(), + author = "general.author".str(), + doi = "general.doi".str(), + url = "general.url".str(), + repoUrl = "general.repo_url".str(), + license = "general.license".str(), + licenseLink = "general.license.link".str() + ).takeUnless { + organization == null && author == null && doi == null && + url == null && repoUrl == null && license == null && licenseLink == null + } + + val additional = GgufMetadata.AdditionalInfo( + type = "general.type".str(), + description = "general.description".str(), + tags = "general.tags".strList(), + languages = "general.languages".strList() + ).takeUnless { + type == null && description == null && tags == null && languages == null + } + + val architectureInfo = GgufMetadata.ArchitectureInfo( + architecture = arch, + fileType = "general.file_type".u32(), + vocabSize = "$arch.vocab_size".u32(), + finetune = "general.finetune".str(), + quantizationVersion = "general.quantization_version".u32() + ).takeUnless { fileType == null && vocabSize == null && finetune == null && quantizationVersion == null } + + val baseModels = buildList { + val n = "general.base_model.count".u32() ?: 0 + for (i in 0 until n) { + fun k(s: 
String) = "general.base_model.$i.$s" + add( + GgufMetadata.BaseModelInfo( + name = k("name").str(), + author = k("author").str(), + version = k("version").str(), + organization = k("organization").str(), + url = k("url").str(), + doi = k("doi").str(), + uuid = k("uuid").str(), + repoUrl = k("repo_url").str(), + ) + ) + } + }.takeIf { it.isNotEmpty() } + + val tokenizer = GgufMetadata.TokenizerInfo( + model = "tokenizer.ggml.model".str(), + bosTokenId = "tokenizer.ggml.bos_token_id".u32(), + eosTokenId = "tokenizer.ggml.eos_token_id".u32(), + unknownTokenId = "tokenizer.ggml.unknown_token_id".u32(), + paddingTokenId = "tokenizer.ggml.padding_token_id".u32(), + addBosToken = "tokenizer.ggml.add_bos_token".bool(), + addEosToken = "tokenizer.ggml.add_eos_token".bool(), + chatTemplate = "tokenizer.chat_template".str() + ).takeUnless { model == null && bosTokenId == null && eosTokenId == null && + unknownTokenId == null && paddingTokenId == null && + addBosToken == null && addEosToken == null && chatTemplate == null + } + + val dimensions = GgufMetadata.DimensionsInfo( + contextLength = "$arch.context_length".u32(), + embeddingSize = "$arch.embedding_length".u32(), + blockCount = "$arch.block_count".u32(), + feedForwardSize = "$arch.feed_forward_length".u32() + ).takeUnless { contextLength == null && embeddingSize == null && blockCount == null && feedForwardSize == null } + + val attention = GgufMetadata.AttentionInfo( + headCount = "$arch.attention.head_count".u32(), + headCountKv = "$arch.attention.head_count_kv".u32(), + keyLength = "$arch.attention.key_length".u32(), + valueLength = "$arch.attention.value_length".u32(), + layerNormEpsilon = "$arch.attention.layer_norm_epsilon".f32(), + layerNormRmsEpsilon = "$arch.attention.layer_norm_rms_epsilon".f32(), + ).takeUnless { headCount == null && headCountKv == null && keyLength == null && valueLength == null && + layerNormEpsilon == null && layerNormRmsEpsilon == null + } + + val rope = GgufMetadata.RopeInfo( + 
frequencyBase = "$arch.rope.freq_base".f32(), + dimensionCount = "$arch.rope.dimension_count".u32(), + scalingType = "$arch.rope.scaling.type".str(), + scalingFactor = "$arch.rope.scaling.factor".f32(), + attnFactor = "$arch.rope.scaling.attn_factor".f32(), + originalContextLength = "$arch.rope.scaling.original_context_length".u32(), + finetuned = "$arch.rope.scaling.finetuned".bool() + ).takeUnless { frequencyBase == null && dimensionCount == null && + scalingType == null && scalingFactor == null && attnFactor == null && + originalContextLength == null && finetuned == null + } + + val experts = GgufMetadata.ExpertsInfo( + count = "$arch.expert_count".u32(), + usedCount = "$arch.expert_used_count".u32() + ).takeUnless { count == null && usedCount == null } + + return GgufMetadata( + version = version, + tensorCount = tensorCnt, + kvCount = kvCnt, + basic = basic, + author = author, + additional = additional, + architecture = architectureInfo, + baseModels = baseModels, + tokenizer = tokenizer, + dimensions = dimensions, + attention = attention, + rope = rope, + experts = experts + ) + } + + /** + * Recursively parses a metadata value of the given type from the input stream. + * @param input The input stream positioned at the start of the value. + * @param type The metadata value type to parse. 
+ */ + private fun parseValue(input: InputStream, type: MetadataType): MetadataValue = when (type) { + MetadataType.UINT8 -> { + // 1-byte unsigned integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading uint8 value.") + MetadataValue.UInt8(byteVal.toUByte()) + } + MetadataType.INT8 -> { + // 1-byte signed integer + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading int8 value.") + MetadataValue.Int8(byteVal.toByte()) + } + MetadataType.UINT16 -> { + // 2-byte unsigned integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading uint16 value.") + // Combine two bytes (little-endian) into an unsigned 16-bit value + val u16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.UInt16(u16.toUShort()) + } + MetadataType.INT16 -> { + // 2-byte signed integer (little-endian) + val bytes = ByteArray(2) + if (input.read(bytes) != 2) throw IOException("Unexpected EOF while reading int16 value.") + // Combine to 16-bit and interpret as signed + val i16 = ((bytes[1].toInt() and 0xFF) shl 8) or (bytes[0].toInt() and 0xFF) + MetadataValue.Int16(i16.toShort()) + } + MetadataType.UINT32 -> { + // 4-byte unsigned integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading uint32 value.") + // Combine four bytes into a 32-bit value (as Long to avoid overflow), then convert to UInt + val u32 = (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.UInt32(u32.toUInt()) + } + MetadataType.INT32 -> { + // 4-byte signed integer (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading int32 value.") + // Combine four bytes into a 32-bit signed int + val 
i32 = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + MetadataValue.Int32(i32) + } + MetadataType.FLOAT32 -> { + // 4-byte IEEE 754 float (little-endian) + val bytes = ByteArray(4) + if (input.read(bytes) != 4) throw IOException("Unexpected EOF while reading float32 value.") + // Assemble 4 bytes into a 32-bit int bit-pattern, then convert to Float + val bits = (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + val floatVal = Float.fromBits(bits) + MetadataValue.Float32(floatVal) + } + MetadataType.BOOL -> { + // 1-byte boolean (0 = false, 1 = true) + val byteVal = input.read() + if (byteVal == -1) throw IOException("Unexpected EOF while reading boolean value.") + if (byteVal != 0 && byteVal != 1) { + throw IOException("Invalid boolean value: $byteVal (must be 0 or 1).") + } + MetadataValue.Bool(byteVal != 0) + } + MetadataType.STRING -> { + // UTF-8 string (length-prefixed with 8-byte length) + val str = readString(input) + MetadataValue.StringVal(str) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + val count = len.toInt() + + if (arraySummariseThreshold >= 0 && count > arraySummariseThreshold) { + // fast‑forward without allocation + repeat(count) { skipValue(input, elemType) } + MetadataValue.StringVal("Array($elemType, $count items) /* summarised */") + } else { + val list = ArrayList<MetadataValue>(count) + repeat(count) { list += parseValue(input, elemType) } + MetadataValue.ArrayVal(elemType, list) + } + } + MetadataType.UINT64 -> { + // 8-byte unsigned integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading uint64 value.") + // Combine 8 bytes into an unsigned 64-bit (ULong). 
Use ULong for full 0 to 2^64-1 range. + val u64 = (bytes[7].toULong() and 0xFFuL shl 56) or + (bytes[6].toULong() and 0xFFuL shl 48) or + (bytes[5].toULong() and 0xFFuL shl 40) or + (bytes[4].toULong() and 0xFFuL shl 32) or + (bytes[3].toULong() and 0xFFuL shl 24) or + (bytes[2].toULong() and 0xFFuL shl 16) or + (bytes[1].toULong() and 0xFFuL shl 8) or + (bytes[0].toULong() and 0xFFuL) + MetadataValue.UInt64(u64) + } + MetadataType.INT64 -> { + // 8-byte signed integer (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading int64 value.") + // Combine 8 bytes into a signed 64-bit value (Long) + val i64 = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + MetadataValue.Int64(i64) + } + MetadataType.FLOAT64 -> { + // 8-byte IEEE 754 double (little-endian) + val bytes = ByteArray(8) + if (input.read(bytes) != 8) throw IOException("Unexpected EOF while reading float64 value.") + // Assemble 8 bytes into a 64-bit bit-pattern, then convert to Double + val bits = (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + val doubleVal = Double.fromBits(bits) + MetadataValue.Float64(doubleVal) + } + } + + + private fun <T> T?.takeUnless(check: T.() -> Boolean): T? = + this?.takeIf { !it.check() } + + /** Helper: Skip a value in the stream without storing it (still maintains pointer). 
*/ + private fun skipValue(input: InputStream, type: MetadataType) { + when (type) { + MetadataType.UINT8, MetadataType.INT8, MetadataType.BOOL -> input.skipFully(1) + MetadataType.UINT16, MetadataType.INT16 -> input.skipFully(2) + MetadataType.UINT32, MetadataType.INT32, MetadataType.FLOAT32 -> input.skipFully(4) + MetadataType.UINT64, MetadataType.INT64, MetadataType.FLOAT64 -> input.skipFully(8) + MetadataType.STRING -> { + val len = readLittleLong(input); input.skipFully(len) + } + MetadataType.ARRAY -> { + val elemType = MetadataType.fromCode(littleEndianBytesToInt(input.readNBytesExact(4))) + val len = readLittleLong(input) + repeat(len.toInt()) { skipValue(input, elemType) } // recursive skip + } + } + } + + /** Helper: Read an 8-byte little-endian unsigned value and return it as a signed Long (assuming it fits in 63 bits). */ + private fun readLittleLong(input: InputStream): Long { + val bytes = ByteArray(8) + input.readFully(bytes) + + // Combine 8 bytes into a 64-bit value (Little Endian). + // Note: If the value exceeds Long.MAX_VALUE (bit 63 is 1), this will produce a negative Long (two's complement). + // In our context (lengths/counts), such extremely large values are not expected. + return (bytes[7].toLong() and 0xFFL shl 56) or + (bytes[6].toLong() and 0xFFL shl 48) or + (bytes[5].toLong() and 0xFFL shl 40) or + (bytes[4].toLong() and 0xFFL shl 32) or + (bytes[3].toLong() and 0xFFL shl 24) or + (bytes[2].toLong() and 0xFFL shl 16) or + (bytes[1].toLong() and 0xFFL shl 8) or + (bytes[0].toLong() and 0xFFL) + } + + /** Helper: Read a GGUF string from the stream (8-byte length followed by UTF-8 bytes). */ + private fun readString(input: InputStream): String = + // Read 8-byte little-endian length (number of bytes in the string). + readLittleLong(input).let { len -> + if (len < 0 || len > Int.MAX_VALUE) throw IOException("String too long: $len") + + // Read the UTF-8 bytes of the given length. 
+ ByteArray(len.toInt()).let { + if (it.isNotEmpty()) input.readFully(it) + String(it, Charsets.UTF_8) + } + } + + /** Helper: Convert a 4-byte little-endian byte array to a 32-bit integer. */ + private fun littleEndianBytesToInt(bytes: ByteArray): Int = + // Note: assumes bytes length is 4. + (bytes[3].toInt() and 0xFF shl 24) or + (bytes[2].toInt() and 0xFF shl 16) or + (bytes[1].toInt() and 0xFF shl 8) or + (bytes[0].toInt() and 0xFF) + + /** + * Robust skip that works the same on JDK 11 and Android’s desugared runtime. + * + * @param n Number of bytes to advance in the stream. + * @throws IOException on premature EOF. + */ + private fun InputStream.skipFully(n: Long) { + var remaining = n + val scratch = ByteArray(8192) // read‑and‑toss buffer + while (remaining > 0) { + val skipped = skip(remaining) + when { + skipped > 0 -> remaining -= skipped // normal fast path + skipped == 0L -> { + // fallback: read and discard + val read = read(scratch, 0, minOf(remaining, scratch.size.toLong()).toInt()) + if (read == -1) throw IOException("EOF while skipping $n bytes") + remaining -= read + } + else -> throw IOException("Skip returned negative value") + } + } + } + + /** + * Extension that keeps reading until the requested number of bytes are filled. + * Falls back to `read()` when `skip()` returns 0, which happens on some Android + * streams. + * + * @param buf Destination buffer. + * @param len Number of bytes to fill (defaults to `buf.size`). + * @throws IOException on premature EOF. + */ + private fun InputStream.readFully(buf: ByteArray, len: Int = buf.size) { + var off = 0 + while (off < len) { + val n = read(buf, off, len - off) + if (n == -1) throw IOException("EOF after $off of $len bytes") + off += n + } + } + + /** + * Read EXACTLY `n` bytes or throw – never returns a partially‑filled array. + * This is used for small fixed‑length reads (e.g. 4‑byte type codes). + * + * @throws IOException on premature EOF. 
+ */ + private fun InputStream.readNBytesExact(n: Int) = ByteArray(n).also { + if (read(it) != n) throw IOException("Unexpected EOF") + } +} diff --git a/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt b/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt new file mode 100644 index 0000000..cbbb974 --- /dev/null +++ b/llama.cpp/examples/llama.android/lib/src/test/java/android/llama/cpp/ExampleUnitTest.kt @@ -0,0 +1,17 @@ +package android.llama.cpp + +import org.junit.Test + +import org.junit.Assert.* + +/** + * Example local unit test, which will execute on the development machine (host). + * + * See [testing documentation](http://d.android.com/tools/testing). + */ +class ExampleUnitTest { + @Test + fun addition_isCorrect() { + assertEquals(4, 2 + 2) + } +} diff --git a/llama.cpp/examples/llama.android/settings.gradle.kts b/llama.cpp/examples/llama.android/settings.gradle.kts new file mode 100644 index 0000000..74f4eb3 --- /dev/null +++ b/llama.cpp/examples/llama.android/settings.gradle.kts @@ -0,0 +1,18 @@ +pluginManagement { + repositories { + google() + mavenCentral() + gradlePluginPortal() + } +} +dependencyResolutionManagement { + repositoriesMode.set(RepositoriesMode.FAIL_ON_PROJECT_REPOS) + repositories { + mavenCentral() + google() + } +} + +rootProject.name = "AiChat" +include(":app") +include(":lib") |
