summaryrefslogtreecommitdiff
path: root/llama.cpp/examples/llama.android/app/src/main/java/com
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/examples/llama.android/app/src/main/java/com')
-rw-r--r--llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt275
-rw-r--r--llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt51
2 files changed, 326 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
new file mode 100644
index 0000000..872ec2b
--- /dev/null
+++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -0,0 +1,275 @@
+package com.example.llama
+
+import android.net.Uri
+import android.os.Bundle
+import android.util.Log
+import android.widget.EditText
+import android.widget.TextView
+import android.widget.Toast
+import androidx.activity.addCallback
+import androidx.activity.enableEdgeToEdge
+import androidx.activity.result.contract.ActivityResultContracts
+import androidx.appcompat.app.AppCompatActivity
+import androidx.lifecycle.lifecycleScope
+import androidx.recyclerview.widget.LinearLayoutManager
+import androidx.recyclerview.widget.RecyclerView
+import com.arm.aichat.AiChat
+import com.arm.aichat.InferenceEngine
+import com.arm.aichat.gguf.GgufMetadata
+import com.arm.aichat.gguf.GgufMetadataReader
+import com.google.android.material.floatingactionbutton.FloatingActionButton
+import kotlinx.coroutines.Dispatchers
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.launch
+import kotlinx.coroutines.withContext
+import java.io.File
+import java.io.FileOutputStream
+import java.io.InputStream
+import java.util.UUID
+
+class MainActivity : AppCompatActivity() {
+
+ // Android views
+ private lateinit var ggufTv: TextView
+ private lateinit var messagesRv: RecyclerView
+ private lateinit var userInputEt: EditText
+ private lateinit var userActionFab: FloatingActionButton
+
+ // Arm AI Chat inference engine
+ private lateinit var engine: InferenceEngine
+ private var generationJob: Job? = null
+
+ // Conversation states
+ private var isModelReady = false
+ private val messages = mutableListOf<Message>()
+ private val lastAssistantMsg = StringBuilder()
+ private val messageAdapter = MessageAdapter(messages)
+
+ override fun onCreate(savedInstanceState: Bundle?) {
+ super.onCreate(savedInstanceState)
+ enableEdgeToEdge()
+ setContentView(R.layout.activity_main)
+ // View model boilerplate and state management is out of this basic sample's scope
+ onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
+
+ // Find views
+ ggufTv = findViewById(R.id.gguf)
+ messagesRv = findViewById(R.id.messages)
+ messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
+ messagesRv.adapter = messageAdapter
+ userInputEt = findViewById(R.id.user_input)
+ userActionFab = findViewById(R.id.fab)
+
+ // Arm AI Chat initialization
+ lifecycleScope.launch(Dispatchers.Default) {
+ engine = AiChat.getInferenceEngine(applicationContext)
+ }
+
+ // Upon CTA button tapped
+ userActionFab.setOnClickListener {
+ if (isModelReady) {
+ // If model is ready, validate input and send to engine
+ handleUserInput()
+ } else {
+ // Otherwise, prompt user to select a GGUF metadata on the device
+ getContent.launch(arrayOf("*/*"))
+ }
+ }
+ }
+
+ private val getContent = registerForActivityResult(
+ ActivityResultContracts.OpenDocument()
+ ) { uri ->
+ Log.i(TAG, "Selected file uri:\n $uri")
+ uri?.let { handleSelectedModel(it) }
+ }
+
+ /**
+ * Handles the file Uri from [getContent] result
+ */
+ private fun handleSelectedModel(uri: Uri) {
+ // Update UI states
+ userActionFab.isEnabled = false
+ userInputEt.hint = "Parsing GGUF..."
+ ggufTv.text = "Parsing metadata from selected file \n$uri"
+
+ lifecycleScope.launch(Dispatchers.IO) {
+ // Parse GGUF metadata
+ Log.i(TAG, "Parsing GGUF metadata...")
+ contentResolver.openInputStream(uri)?.use {
+ GgufMetadataReader.create().readStructuredMetadata(it)
+ }?.let { metadata ->
+ // Update UI to show GGUF metadata to user
+ Log.i(TAG, "GGUF parsed: \n$metadata")
+ withContext(Dispatchers.Main) {
+ ggufTv.text = metadata.toString()
+ }
+
+ // Ensure the model file is available
+ val modelName = metadata.filename() + FILE_EXTENSION_GGUF
+ contentResolver.openInputStream(uri)?.use { input ->
+ ensureModelFile(modelName, input)
+ }?.let { modelFile ->
+ loadModel(modelName, modelFile)
+
+ withContext(Dispatchers.Main) {
+ isModelReady = true
+ userInputEt.hint = "Type and send a message!"
+ userInputEt.isEnabled = true
+ userActionFab.setImageResource(R.drawable.outline_send_24)
+ userActionFab.isEnabled = true
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Prepare the model file within app's private storage
+ */
+ private suspend fun ensureModelFile(modelName: String, input: InputStream) =
+ withContext(Dispatchers.IO) {
+ File(ensureModelsDirectory(), modelName).also { file ->
+ // Copy the file into local storage if not yet done
+ if (!file.exists()) {
+ Log.i(TAG, "Start copying file to $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Copying file..."
+ }
+
+ FileOutputStream(file).use { input.copyTo(it) }
+ Log.i(TAG, "Finished copying file to $modelName")
+ } else {
+ Log.i(TAG, "File already exists $modelName")
+ }
+ }
+ }
+
+ /**
+ * Load the model file from the app private storage
+ */
+ private suspend fun loadModel(modelName: String, modelFile: File) =
+ withContext(Dispatchers.IO) {
+ Log.i(TAG, "Loading model $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Loading model..."
+ }
+ engine.loadModel(modelFile.path)
+ }
+
+ /**
+ * Validate and send the user message into [InferenceEngine]
+ */
+ private fun handleUserInput() {
+ userInputEt.text.toString().also { userMsg ->
+ if (userMsg.isEmpty()) {
+ Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
+ } else {
+ userInputEt.text = null
+ userInputEt.isEnabled = false
+ userActionFab.isEnabled = false
+
+ // Update message states
+ messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
+ lastAssistantMsg.clear()
+ messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
+
+ generationJob = lifecycleScope.launch(Dispatchers.Default) {
+ engine.sendUserPrompt(userMsg)
+ .onCompletion {
+ withContext(Dispatchers.Main) {
+ userInputEt.isEnabled = true
+ userActionFab.isEnabled = true
+ }
+ }.collect { token ->
+ withContext(Dispatchers.Main) {
+ val messageCount = messages.size
+ check(messageCount > 0 && !messages[messageCount - 1].isUser)
+
+ messages.removeAt(messageCount - 1).copy(
+ content = lastAssistantMsg.append(token).toString()
+ ).let { messages.add(it) }
+
+ messageAdapter.notifyItemChanged(messages.size - 1)
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Run a benchmark with the model file
+ */
+ @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
+ private suspend fun runBenchmark(modelName: String, modelFile: File) =
+ withContext(Dispatchers.Default) {
+ Log.i(TAG, "Starts benchmarking $modelName")
+ withContext(Dispatchers.Main) {
+ userInputEt.hint = "Running benchmark..."
+ }
+ engine.bench(
+ pp=BENCH_PROMPT_PROCESSING_TOKENS,
+ tg=BENCH_TOKEN_GENERATION_TOKENS,
+ pl=BENCH_SEQUENCE,
+ nr=BENCH_REPETITION
+ ).let { result ->
+ messages.add(Message(UUID.randomUUID().toString(), result, false))
+ withContext(Dispatchers.Main) {
+ messageAdapter.notifyItemChanged(messages.size - 1)
+ }
+ }
+ }
+
+ /**
+ * Create the `models` directory if not exist.
+ */
+ private fun ensureModelsDirectory() =
+ File(filesDir, DIRECTORY_MODELS).also {
+ if (it.exists() && !it.isDirectory) { it.delete() }
+ if (!it.exists()) { it.mkdir() }
+ }
+
+ override fun onStop() {
+ generationJob?.cancel()
+ super.onStop()
+ }
+
+ override fun onDestroy() {
+ engine.destroy()
+ super.onDestroy()
+ }
+
+ companion object {
+ private val TAG = MainActivity::class.java.simpleName
+
+ private const val DIRECTORY_MODELS = "models"
+ private const val FILE_EXTENSION_GGUF = ".gguf"
+
+ private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
+ private const val BENCH_TOKEN_GENERATION_TOKENS = 128
+ private const val BENCH_SEQUENCE = 1
+ private const val BENCH_REPETITION = 3
+ }
+}
+
+fun GgufMetadata.filename() = when {
+ basic.name != null -> {
+ basic.name?.let { name ->
+ basic.sizeLabel?.let { size ->
+ "$name-$size"
+ } ?: name
+ }
+ }
+ architecture?.architecture != null -> {
+ architecture?.architecture?.let { arch ->
+ basic.uuid?.let { uuid ->
+ "$arch-$uuid"
+ } ?: "$arch-${System.currentTimeMillis()}"
+ }
+ }
+ else -> {
+ "model-${System.currentTimeMillis().toHexString()}"
+ }
+}
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
new file mode 100644
index 0000000..0439f96
--- /dev/null
+++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
@@ -0,0 +1,51 @@
+package com.example.llama
+
+import android.view.LayoutInflater
+import android.view.View
+import android.view.ViewGroup
+import android.widget.TextView
+import androidx.recyclerview.widget.RecyclerView
+
+data class Message(
+ val id: String,
+ val content: String,
+ val isUser: Boolean
+)
+
+class MessageAdapter(
+ private val messages: List<Message>
+) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
+
+ companion object {
+ private const val VIEW_TYPE_USER = 1
+ private const val VIEW_TYPE_ASSISTANT = 2
+ }
+
+ override fun getItemViewType(position: Int): Int {
+ return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
+ }
+
+ override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
+ val layoutInflater = LayoutInflater.from(parent.context)
+ return if (viewType == VIEW_TYPE_USER) {
+ val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
+ UserMessageViewHolder(view)
+ } else {
+ val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
+ AssistantMessageViewHolder(view)
+ }
+ }
+
+ override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
+ val message = messages[position]
+ if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
+ val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
+ textView.text = message.content
+ }
+ }
+
+ override fun getItemCount(): Int = messages.size
+
+ class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
+ class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
+}