1package com.example.llama
  2
  3import android.net.Uri
  4import android.os.Bundle
  5import android.util.Log
  6import android.widget.EditText
  7import android.widget.TextView
  8import android.widget.Toast
  9import androidx.activity.addCallback
 10import androidx.activity.enableEdgeToEdge
 11import androidx.activity.result.contract.ActivityResultContracts
 12import androidx.appcompat.app.AppCompatActivity
 13import androidx.lifecycle.lifecycleScope
 14import androidx.recyclerview.widget.LinearLayoutManager
 15import androidx.recyclerview.widget.RecyclerView
 16import com.arm.aichat.AiChat
 17import com.arm.aichat.InferenceEngine
 18import com.arm.aichat.gguf.GgufMetadata
 19import com.arm.aichat.gguf.GgufMetadataReader
 20import com.google.android.material.floatingactionbutton.FloatingActionButton
 21import kotlinx.coroutines.Dispatchers
 22import kotlinx.coroutines.Job
 23import kotlinx.coroutines.flow.onCompletion
 24import kotlinx.coroutines.launch
 25import kotlinx.coroutines.withContext
 26import java.io.File
 27import java.io.FileOutputStream
 28import java.io.InputStream
 29import java.util.UUID
 30
 31class MainActivity : AppCompatActivity() {
 32
 33    // Android views
 34    private lateinit var ggufTv: TextView
 35    private lateinit var messagesRv: RecyclerView
 36    private lateinit var userInputEt: EditText
 37    private lateinit var userActionFab: FloatingActionButton
 38
 39    // Arm AI Chat inference engine
 40    private lateinit var engine: InferenceEngine
 41    private var generationJob: Job? = null
 42
 43    // Conversation states
 44    private var isModelReady = false
 45    private val messages = mutableListOf<Message>()
 46    private val lastAssistantMsg = StringBuilder()
 47    private val messageAdapter = MessageAdapter(messages)
 48
 49    override fun onCreate(savedInstanceState: Bundle?) {
 50        super.onCreate(savedInstanceState)
 51        enableEdgeToEdge()
 52        setContentView(R.layout.activity_main)
 53        // View model boilerplate and state management is out of this basic sample's scope
 54        onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
 55
 56        // Find views
 57        ggufTv = findViewById(R.id.gguf)
 58        messagesRv = findViewById(R.id.messages)
 59        messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
 60        messagesRv.adapter = messageAdapter
 61        userInputEt = findViewById(R.id.user_input)
 62        userActionFab = findViewById(R.id.fab)
 63
 64        // Arm AI Chat initialization
 65        lifecycleScope.launch(Dispatchers.Default) {
 66            engine = AiChat.getInferenceEngine(applicationContext)
 67        }
 68
 69        // Upon CTA button tapped
 70        userActionFab.setOnClickListener {
 71            if (isModelReady) {
 72                // If model is ready, validate input and send to engine
 73                handleUserInput()
 74            } else {
 75                // Otherwise, prompt user to select a GGUF metadata on the device
 76                getContent.launch(arrayOf("*/*"))
 77            }
 78        }
 79    }
 80
 81    private val getContent = registerForActivityResult(
 82        ActivityResultContracts.OpenDocument()
 83    ) { uri ->
 84        Log.i(TAG, "Selected file uri:\n $uri")
 85        uri?.let { handleSelectedModel(it) }
 86    }
 87
 88    /**
 89     * Handles the file Uri from [getContent] result
 90     */
 91    private fun handleSelectedModel(uri: Uri) {
 92        // Update UI states
 93        userActionFab.isEnabled = false
 94        userInputEt.hint = "Parsing GGUF..."
 95        ggufTv.text = "Parsing metadata from selected file \n$uri"
 96
 97        lifecycleScope.launch(Dispatchers.IO) {
 98            // Parse GGUF metadata
 99            Log.i(TAG, "Parsing GGUF metadata...")
100            contentResolver.openInputStream(uri)?.use {
101                GgufMetadataReader.create().readStructuredMetadata(it)
102            }?.let { metadata ->
103                // Update UI to show GGUF metadata to user
104                Log.i(TAG, "GGUF parsed: \n$metadata")
105                withContext(Dispatchers.Main) {
106                    ggufTv.text = metadata.toString()
107                }
108
109                // Ensure the model file is available
110                val modelName = metadata.filename() + FILE_EXTENSION_GGUF
111                contentResolver.openInputStream(uri)?.use { input ->
112                    ensureModelFile(modelName, input)
113                }?.let { modelFile ->
114                    loadModel(modelName, modelFile)
115
116                    withContext(Dispatchers.Main) {
117                        isModelReady = true
118                        userInputEt.hint = "Type and send a message!"
119                        userInputEt.isEnabled = true
120                        userActionFab.setImageResource(R.drawable.outline_send_24)
121                        userActionFab.isEnabled = true
122                    }
123                }
124            }
125        }
126    }
127
128    /**
129     * Prepare the model file within app's private storage
130     */
131    private suspend fun ensureModelFile(modelName: String, input: InputStream) =
132        withContext(Dispatchers.IO) {
133            File(ensureModelsDirectory(), modelName).also { file ->
134                // Copy the file into local storage if not yet done
135                if (!file.exists()) {
136                    Log.i(TAG, "Start copying file to $modelName")
137                    withContext(Dispatchers.Main) {
138                        userInputEt.hint = "Copying file..."
139                    }
140
141                    FileOutputStream(file).use { input.copyTo(it) }
142                    Log.i(TAG, "Finished copying file to $modelName")
143                } else {
144                    Log.i(TAG, "File already exists $modelName")
145                }
146            }
147        }
148
149    /**
150     * Load the model file from the app private storage
151     */
152    private suspend fun loadModel(modelName: String, modelFile: File) =
153        withContext(Dispatchers.IO) {
154            Log.i(TAG, "Loading model $modelName")
155            withContext(Dispatchers.Main) {
156                userInputEt.hint = "Loading model..."
157            }
158            engine.loadModel(modelFile.path)
159        }
160
161    /**
162     * Validate and send the user message into [InferenceEngine]
163     */
164    private fun handleUserInput() {
165        userInputEt.text.toString().also { userMsg ->
166            if (userMsg.isEmpty()) {
167                Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
168            } else {
169                userInputEt.text = null
170                userInputEt.isEnabled = false
171                userActionFab.isEnabled = false
172
173                // Update message states
174                messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
175                lastAssistantMsg.clear()
176                messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
177
178                generationJob = lifecycleScope.launch(Dispatchers.Default) {
179                    engine.sendUserPrompt(userMsg)
180                        .onCompletion {
181                            withContext(Dispatchers.Main) {
182                                userInputEt.isEnabled = true
183                                userActionFab.isEnabled = true
184                            }
185                        }.collect { token ->
186                            withContext(Dispatchers.Main) {
187                                val messageCount = messages.size
188                                check(messageCount > 0 && !messages[messageCount - 1].isUser)
189
190                                messages.removeAt(messageCount - 1).copy(
191                                    content = lastAssistantMsg.append(token).toString()
192                                ).let { messages.add(it) }
193
194                                messageAdapter.notifyItemChanged(messages.size - 1)
195                            }
196                        }
197                }
198            }
199        }
200    }
201
202    /**
203     * Run a benchmark with the model file
204     */
205    @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
206    private suspend fun runBenchmark(modelName: String, modelFile: File) =
207        withContext(Dispatchers.Default) {
208            Log.i(TAG, "Starts benchmarking $modelName")
209            withContext(Dispatchers.Main) {
210                userInputEt.hint = "Running benchmark..."
211            }
212            engine.bench(
213                pp=BENCH_PROMPT_PROCESSING_TOKENS,
214                tg=BENCH_TOKEN_GENERATION_TOKENS,
215                pl=BENCH_SEQUENCE,
216                nr=BENCH_REPETITION
217            ).let { result ->
218                messages.add(Message(UUID.randomUUID().toString(), result, false))
219                withContext(Dispatchers.Main) {
220                    messageAdapter.notifyItemChanged(messages.size - 1)
221                }
222            }
223        }
224
225    /**
226     * Create the `models` directory if not exist.
227     */
228    private fun ensureModelsDirectory() =
229        File(filesDir, DIRECTORY_MODELS).also {
230            if (it.exists() && !it.isDirectory) { it.delete() }
231            if (!it.exists()) { it.mkdir() }
232        }
233
234    override fun onStop() {
235        generationJob?.cancel()
236        super.onStop()
237    }
238
239    override fun onDestroy() {
240        engine.destroy()
241        super.onDestroy()
242    }
243
244    companion object {
245        private val TAG = MainActivity::class.java.simpleName
246
247        private const val DIRECTORY_MODELS = "models"
248        private const val FILE_EXTENSION_GGUF = ".gguf"
249
250        private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
251        private const val BENCH_TOKEN_GENERATION_TOKENS = 128
252        private const val BENCH_SEQUENCE = 1
253        private const val BENCH_REPETITION = 3
254    }
255}
256
/**
 * Derives a file name for a model from its GGUF metadata.
 *
 * Preference order:
 *  1. `<name>-<sizeLabel>` (or just `<name>`) when the basic name is present,
 *  2. `<architecture>-<uuid>` (or `<architecture>-<millis>` when no uuid) otherwise,
 *  3. `model-<millis-as-hex>` as a last resort.
 *
 * Reads each nullable property into a local first so the result is a guaranteed
 * non-null [String] — the original `x != null -> x?.let { ... }` shape typed every
 * branch (and the whole expression) as `String?`, which leaked nullability to the
 * caller building the on-disk file name.
 */
fun GgufMetadata.filename(): String {
    val name = basic.name
    val arch = architecture?.architecture
    return when {
        name != null -> basic.sizeLabel?.let { size -> "$name-$size" } ?: name
        arch != null -> basic.uuid?.let { uuid -> "$arch-$uuid" } ?: "$arch-${System.currentTimeMillis()}"
        else -> "model-${System.currentTimeMillis().toHexString()}"
    }
}