diff options
Diffstat (limited to 'llama.cpp/examples/llama.android/app/src/main/java/com/example/llama')
| -rw-r--r-- | llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt | 275 | ||||
| -rw-r--r-- | llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt | 51 |
2 files changed, 326 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt new file mode 100644 index 0000000..872ec2b --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt | |||
| @@ -0,0 +1,275 @@ | |||
| 1 | package com.example.llama | ||
| 2 | |||
import android.net.Uri
import android.os.Bundle
import android.util.Log
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.activity.addCallback
import androidx.activity.enableEdgeToEdge
import androidx.activity.result.contract.ActivityResultContracts
import androidx.appcompat.app.AppCompatActivity
import androidx.lifecycle.lifecycleScope
import androidx.recyclerview.widget.LinearLayoutManager
import androidx.recyclerview.widget.RecyclerView
import com.arm.aichat.AiChat
import com.arm.aichat.InferenceEngine
import com.arm.aichat.gguf.GgufMetadata
import com.arm.aichat.gguf.GgufMetadataReader
import com.google.android.material.floatingactionbutton.FloatingActionButton
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import java.io.File
import java.io.FileOutputStream
import java.io.InputStream
import java.util.UUID
| 30 | |||
/**
 * Single-screen chat sample.
 *
 * The user first picks a GGUF file; its metadata is shown, the file is copied into the app's
 * private `models` directory and loaded into the [InferenceEngine]. After that the same
 * button sends the typed prompt and the streamed tokens are appended to the last message.
 */
class MainActivity : AppCompatActivity() {

    // Android views
    private lateinit var ggufTv: TextView
    private lateinit var messagesRv: RecyclerView
    private lateinit var userInputEt: EditText
    private lateinit var userActionFab: FloatingActionButton

    // Arm AI Chat inference engine
    private lateinit var engine: InferenceEngine

    // Job that assigns [engine]; joined before the engine is first used
    private var engineInitJob: Job? = null
    private var generationJob: Job? = null

    // Conversation states
    private var isModelReady = false
    private val messages = mutableListOf<Message>()
    private val lastAssistantMsg = StringBuilder()
    private val messageAdapter = MessageAdapter(messages)

    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        enableEdgeToEdge()
        setContentView(R.layout.activity_main)
        // View model boilerplate and state management is out of this basic sample's scope
        onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }

        // Find views
        ggufTv = findViewById(R.id.gguf)
        messagesRv = findViewById(R.id.messages)
        messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
        messagesRv.adapter = messageAdapter
        userInputEt = findViewById(R.id.user_input)
        userActionFab = findViewById(R.id.fab)

        // Arm AI Chat initialization
        engineInitJob = lifecycleScope.launch(Dispatchers.Default) {
            engine = AiChat.getInferenceEngine(applicationContext)
        }

        // Upon CTA button tapped
        userActionFab.setOnClickListener {
            if (isModelReady) {
                // If model is ready, validate input and send to engine
                handleUserInput()
            } else {
                // Otherwise, prompt user to select a GGUF file on the device
                getContent.launch(arrayOf("*/*"))
            }
        }
    }

    private val getContent = registerForActivityResult(
        ActivityResultContracts.OpenDocument()
    ) { uri ->
        Log.i(TAG, "Selected file uri:\n $uri")
        uri?.let { handleSelectedModel(it) }
    }

    /**
     * Handles the file Uri from [getContent] result: parses the GGUF metadata, copies the
     * file into private storage and loads it into the engine.
     *
     * Any failure along the way is logged and reported in the UI, and the action button is
     * re-enabled so the user can pick another file.
     */
    private fun handleSelectedModel(uri: Uri) {
        // Update UI states
        userActionFab.isEnabled = false
        userInputEt.hint = "Parsing GGUF..."
        ggufTv.text = "Parsing metadata from selected file \n$uri"

        lifecycleScope.launch(Dispatchers.IO) {
            try {
                // Parse GGUF metadata
                Log.i(TAG, "Parsing GGUF metadata...")
                val metadata = contentResolver.openInputStream(uri)?.use {
                    GgufMetadataReader.create().readStructuredMetadata(it)
                } ?: error("Unable to read GGUF metadata from $uri")

                // Update UI to show GGUF metadata to user
                Log.i(TAG, "GGUF parsed: \n$metadata")
                withContext(Dispatchers.Main) {
                    ggufTv.text = metadata.toString()
                }

                // Ensure the model file is available (the first stream was consumed by the parser)
                val modelName = metadata.filename() + FILE_EXTENSION_GGUF
                val modelFile = contentResolver.openInputStream(uri)?.use { input ->
                    ensureModelFile(modelName, input)
                } ?: error("Unable to reopen $uri for copying")

                loadModel(modelName, modelFile)

                withContext(Dispatchers.Main) {
                    isModelReady = true
                    userInputEt.hint = "Type and send a message!"
                    userInputEt.isEnabled = true
                    userActionFab.setImageResource(R.drawable.outline_send_24)
                    userActionFab.isEnabled = true
                }
            } catch (e: CancellationException) {
                // Never swallow cancellation
                throw e
            } catch (e: Exception) {
                Log.e(TAG, "Failed to prepare model from $uri", e)
                withContext(Dispatchers.Main) {
                    ggufTv.text = "Failed to load model from \n$uri\n${e.message ?: e.javaClass.simpleName}"
                    userInputEt.hint = "Select a GGUF file to retry"
                    userActionFab.isEnabled = true
                }
            }
        }
    }

    /**
     * Prepare the model file within app's private storage.
     *
     * The stream is first copied to a sibling `.part` file and only renamed to [modelName]
     * once the copy has finished, so an interrupted copy can never be mistaken for a
     * complete model on the next run.
     *
     * @return the model file inside the `models` directory
     * @throws IllegalStateException if the finished copy cannot be moved into place
     */
    private suspend fun ensureModelFile(modelName: String, input: InputStream) =
        withContext(Dispatchers.IO) {
            File(ensureModelsDirectory(), modelName).also { file ->
                if (file.exists()) {
                    Log.i(TAG, "File already exists $modelName")
                    return@also
                }

                // Copy the file into local storage if not yet done
                Log.i(TAG, "Start copying file to $modelName")
                withContext(Dispatchers.Main) {
                    userInputEt.hint = "Copying file..."
                }

                val partial = File(file.parentFile, modelName + FILE_SUFFIX_PARTIAL)
                try {
                    FileOutputStream(partial).use { input.copyTo(it) }
                    check(partial.renameTo(file)) { "Unable to move ${partial.name} to ${file.name}" }
                } finally {
                    // Removes leftovers of a failed copy; does nothing after a successful rename
                    partial.delete()
                }
                Log.i(TAG, "Finished copying file to $modelName")
            }
        }

    /**
     * Load the model file from the app private storage
     */
    private suspend fun loadModel(modelName: String, modelFile: File) =
        withContext(Dispatchers.IO) {
            Log.i(TAG, "Loading model $modelName")
            withContext(Dispatchers.Main) {
                userInputEt.hint = "Loading model..."
            }
            // The engine is obtained asynchronously in onCreate; wait until it is assigned
            engineInitJob?.join()
            engine.loadModel(modelFile.path)
        }

    /**
     * Validate and send the user message into [InferenceEngine]
     */
    private fun handleUserInput() {
        val userMsg = userInputEt.text.toString()
        if (userMsg.isEmpty()) {
            Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
            return
        }

        userInputEt.text = null
        userInputEt.isEnabled = false
        userActionFab.isEnabled = false

        // Update message states: the user message plus an empty assistant placeholder
        val firstInserted = messages.size
        messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
        lastAssistantMsg.clear()
        messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
        // The adapter must learn about the new rows before any of them is reported as changed
        messageAdapter.notifyItemRangeInserted(firstInserted, messages.size - firstInserted)

        generationJob = lifecycleScope.launch(Dispatchers.Default) {
            engine.sendUserPrompt(userMsg)
                .onCompletion {
                    // Also runs when the job was cancelled (see onStop), hence NonCancellable
                    withContext(NonCancellable + Dispatchers.Main) {
                        userInputEt.isEnabled = true
                        userActionFab.isEnabled = true
                    }
                }.collect { token ->
                    withContext(Dispatchers.Main) {
                        val lastIndex = messages.lastIndex
                        check(lastIndex >= 0 && !messages[lastIndex].isUser)

                        // Replace the assistant placeholder with the text accumulated so far
                        messages[lastIndex] = messages[lastIndex].copy(
                            content = lastAssistantMsg.append(token).toString()
                        )
                        messageAdapter.notifyItemChanged(lastIndex)
                    }
                }
        }
    }

    /**
     * Run a benchmark with the model file and append its result as an assistant message
     */
    @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
    private suspend fun runBenchmark(modelName: String, modelFile: File) =
        withContext(Dispatchers.Default) {
            Log.i(TAG, "Starts benchmarking $modelName")
            withContext(Dispatchers.Main) {
                userInputEt.hint = "Running benchmark..."
            }
            engine.bench(
                pp=BENCH_PROMPT_PROCESSING_TOKENS,
                tg=BENCH_TOKEN_GENERATION_TOKENS,
                pl=BENCH_SEQUENCE,
                nr=BENCH_REPETITION
            ).let { result ->
                // The list backs the adapter, so it is only touched on the main thread
                withContext(Dispatchers.Main) {
                    messages.add(Message(UUID.randomUUID().toString(), result, false))
                    messageAdapter.notifyItemInserted(messages.lastIndex)
                }
            }
        }

    /**
     * Create the `models` directory if not exist.
     *
     * @throws IllegalStateException if the directory cannot be created
     */
    private fun ensureModelsDirectory() =
        File(filesDir, DIRECTORY_MODELS).also { dir ->
            // A regular file with the same name would block directory creation
            if (dir.exists() && !dir.isDirectory) { dir.delete() }
            check(dir.isDirectory || dir.mkdirs()) { "Unable to create directory ${dir.path}" }
        }

    override fun onStop() {
        generationJob?.cancel()
        super.onStop()
    }

    override fun onDestroy() {
        // The engine is assigned asynchronously and may not exist yet
        if (::engine.isInitialized) {
            engine.destroy()
        }
        super.onDestroy()
    }

    companion object {
        private val TAG = MainActivity::class.java.simpleName

        private const val DIRECTORY_MODELS = "models"
        private const val FILE_EXTENSION_GGUF = ".gguf"
        private const val FILE_SUFFIX_PARTIAL = ".part"

        private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
        private const val BENCH_TOKEN_GENERATION_TOKENS = 128
        private const val BENCH_SEQUENCE = 1
        private const val BENCH_REPETITION = 3
    }
}
| 256 | |||
// Characters that must not appear in a single path segment: separators and control characters
private val UNSAFE_FILENAME_CHARS = Regex("""[/\\\p{Cntrl}]""")

/**
 * Derives a file name (without extension) for this model from its metadata.
 *
 * Preference order:
 * 1. `basic.name`, suffixed with `basic.sizeLabel` when present;
 * 2. the architecture name, suffixed with `basic.uuid` or, failing that, the current time;
 * 3. a `model-` prefix followed by the current time in hex.
 *
 * The metadata comes from a user-selected file, so path separators and control characters
 * are replaced with `_` to keep the result usable as a single path segment.
 */
fun GgufMetadata.filename(): String {
    // A blank name would produce a file called just ".gguf"; treat it as absent
    val name = basic.name?.takeIf { it.isNotBlank() }
    val arch = architecture?.architecture

    val raw = when {
        name != null -> basic.sizeLabel?.let { size -> "$name-$size" } ?: name
        arch != null -> basic.uuid?.let { uuid -> "$arch-$uuid" } ?: "$arch-${System.currentTimeMillis()}"
        else -> "model-${System.currentTimeMillis().toHexString()}"
    }
    return raw.replace(UNSAFE_FILENAME_CHARS, "_")
}
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt new file mode 100644 index 0000000..0439f96 --- /dev/null +++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt | |||
| @@ -0,0 +1,51 @@ | |||
| 1 | package com.example.llama | ||
| 2 | |||
| 3 | import android.view.LayoutInflater | ||
| 4 | import android.view.View | ||
| 5 | import android.view.ViewGroup | ||
| 6 | import android.widget.TextView | ||
| 7 | import androidx.recyclerview.widget.RecyclerView | ||
| 8 | |||
/**
 * One entry of the conversation shown by [MessageAdapter].
 *
 * @property id identifier of this entry
 * @property content text displayed for this entry
 * @property isUser `true` for a user message, `false` for an assistant message
 */
data class Message(
    val id: String,
    val content: String,
    val isUser: Boolean,
)
| 14 | |||
/**
 * RecyclerView adapter rendering a list of [Message]s, using one layout for user messages
 * and another for assistant messages.
 *
 * The adapter only reads [messages]; whoever mutates the list is responsible for calling
 * the matching `notify*` method.
 */
class MessageAdapter(
    private val messages: List<Message>
) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {

    override fun getItemViewType(position: Int): Int =
        if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT

    override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
        val inflater = LayoutInflater.from(parent.context)
        return when (viewType) {
            VIEW_TYPE_USER ->
                UserMessageViewHolder(inflater.inflate(R.layout.item_message_user, parent, false))
            else ->
                AssistantMessageViewHolder(inflater.inflate(R.layout.item_message_assistant, parent, false))
        }
    }

    override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
        val message = messages[position]
        (holder as? MessageViewHolder)?.bind(message)
    }

    override fun getItemCount(): Int = messages.size

    /**
     * Common base of both holders: looks up the content view once, when the holder is
     * created, instead of on every bind.
     */
    sealed class MessageViewHolder(view: View) : RecyclerView.ViewHolder(view) {
        private val contentTv: TextView = view.findViewById(R.id.msg_content)

        /** Shows the content of [message] in this row. */
        fun bind(message: Message) {
            contentTv.text = message.content
        }
    }

    class UserMessageViewHolder(view: View) : MessageViewHolder(view)
    class AssistantMessageViewHolder(view: View) : MessageViewHolder(view)

    companion object {
        private const val VIEW_TYPE_USER = 1
        private const val VIEW_TYPE_ASSISTANT = 2
    }
}
