summaryrefslogtreecommitdiff
path: root/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/examples/llama.android/app/src/main/java/com/example/llama')
-rw-r--r--llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt275
-rw-r--r--llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt51
2 files changed, 326 insertions, 0 deletions
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
new file mode 100644
index 0000000..872ec2b
--- /dev/null
+++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MainActivity.kt
@@ -0,0 +1,275 @@
1package com.example.llama
2
3import android.net.Uri
4import android.os.Bundle
5import android.util.Log
6import android.widget.EditText
7import android.widget.TextView
8import android.widget.Toast
9import androidx.activity.addCallback
10import androidx.activity.enableEdgeToEdge
11import androidx.activity.result.contract.ActivityResultContracts
12import androidx.appcompat.app.AppCompatActivity
13import androidx.lifecycle.lifecycleScope
14import androidx.recyclerview.widget.LinearLayoutManager
15import androidx.recyclerview.widget.RecyclerView
16import com.arm.aichat.AiChat
17import com.arm.aichat.InferenceEngine
18import com.arm.aichat.gguf.GgufMetadata
19import com.arm.aichat.gguf.GgufMetadataReader
20import com.google.android.material.floatingactionbutton.FloatingActionButton
21import kotlinx.coroutines.Dispatchers
22import kotlinx.coroutines.Job
23import kotlinx.coroutines.flow.onCompletion
24import kotlinx.coroutines.launch
25import kotlinx.coroutines.withContext
26import java.io.File
27import java.io.FileOutputStream
28import java.io.InputStream
29import java.util.UUID
30
31class MainActivity : AppCompatActivity() {
32
33 // Android views
34 private lateinit var ggufTv: TextView
35 private lateinit var messagesRv: RecyclerView
36 private lateinit var userInputEt: EditText
37 private lateinit var userActionFab: FloatingActionButton
38
39 // Arm AI Chat inference engine
40 private lateinit var engine: InferenceEngine
41 private var generationJob: Job? = null
42
43 // Conversation states
44 private var isModelReady = false
45 private val messages = mutableListOf<Message>()
46 private val lastAssistantMsg = StringBuilder()
47 private val messageAdapter = MessageAdapter(messages)
48
49 override fun onCreate(savedInstanceState: Bundle?) {
50 super.onCreate(savedInstanceState)
51 enableEdgeToEdge()
52 setContentView(R.layout.activity_main)
53 // View model boilerplate and state management is out of this basic sample's scope
54 onBackPressedDispatcher.addCallback { Log.w(TAG, "Ignore back press for simplicity") }
55
56 // Find views
57 ggufTv = findViewById(R.id.gguf)
58 messagesRv = findViewById(R.id.messages)
59 messagesRv.layoutManager = LinearLayoutManager(this).apply { stackFromEnd = true }
60 messagesRv.adapter = messageAdapter
61 userInputEt = findViewById(R.id.user_input)
62 userActionFab = findViewById(R.id.fab)
63
64 // Arm AI Chat initialization
65 lifecycleScope.launch(Dispatchers.Default) {
66 engine = AiChat.getInferenceEngine(applicationContext)
67 }
68
69 // Upon CTA button tapped
70 userActionFab.setOnClickListener {
71 if (isModelReady) {
72 // If model is ready, validate input and send to engine
73 handleUserInput()
74 } else {
75 // Otherwise, prompt user to select a GGUF metadata on the device
76 getContent.launch(arrayOf("*/*"))
77 }
78 }
79 }
80
81 private val getContent = registerForActivityResult(
82 ActivityResultContracts.OpenDocument()
83 ) { uri ->
84 Log.i(TAG, "Selected file uri:\n $uri")
85 uri?.let { handleSelectedModel(it) }
86 }
87
88 /**
89 * Handles the file Uri from [getContent] result
90 */
91 private fun handleSelectedModel(uri: Uri) {
92 // Update UI states
93 userActionFab.isEnabled = false
94 userInputEt.hint = "Parsing GGUF..."
95 ggufTv.text = "Parsing metadata from selected file \n$uri"
96
97 lifecycleScope.launch(Dispatchers.IO) {
98 // Parse GGUF metadata
99 Log.i(TAG, "Parsing GGUF metadata...")
100 contentResolver.openInputStream(uri)?.use {
101 GgufMetadataReader.create().readStructuredMetadata(it)
102 }?.let { metadata ->
103 // Update UI to show GGUF metadata to user
104 Log.i(TAG, "GGUF parsed: \n$metadata")
105 withContext(Dispatchers.Main) {
106 ggufTv.text = metadata.toString()
107 }
108
109 // Ensure the model file is available
110 val modelName = metadata.filename() + FILE_EXTENSION_GGUF
111 contentResolver.openInputStream(uri)?.use { input ->
112 ensureModelFile(modelName, input)
113 }?.let { modelFile ->
114 loadModel(modelName, modelFile)
115
116 withContext(Dispatchers.Main) {
117 isModelReady = true
118 userInputEt.hint = "Type and send a message!"
119 userInputEt.isEnabled = true
120 userActionFab.setImageResource(R.drawable.outline_send_24)
121 userActionFab.isEnabled = true
122 }
123 }
124 }
125 }
126 }
127
128 /**
129 * Prepare the model file within app's private storage
130 */
131 private suspend fun ensureModelFile(modelName: String, input: InputStream) =
132 withContext(Dispatchers.IO) {
133 File(ensureModelsDirectory(), modelName).also { file ->
134 // Copy the file into local storage if not yet done
135 if (!file.exists()) {
136 Log.i(TAG, "Start copying file to $modelName")
137 withContext(Dispatchers.Main) {
138 userInputEt.hint = "Copying file..."
139 }
140
141 FileOutputStream(file).use { input.copyTo(it) }
142 Log.i(TAG, "Finished copying file to $modelName")
143 } else {
144 Log.i(TAG, "File already exists $modelName")
145 }
146 }
147 }
148
149 /**
150 * Load the model file from the app private storage
151 */
152 private suspend fun loadModel(modelName: String, modelFile: File) =
153 withContext(Dispatchers.IO) {
154 Log.i(TAG, "Loading model $modelName")
155 withContext(Dispatchers.Main) {
156 userInputEt.hint = "Loading model..."
157 }
158 engine.loadModel(modelFile.path)
159 }
160
161 /**
162 * Validate and send the user message into [InferenceEngine]
163 */
164 private fun handleUserInput() {
165 userInputEt.text.toString().also { userMsg ->
166 if (userMsg.isEmpty()) {
167 Toast.makeText(this, "Input message is empty!", Toast.LENGTH_SHORT).show()
168 } else {
169 userInputEt.text = null
170 userInputEt.isEnabled = false
171 userActionFab.isEnabled = false
172
173 // Update message states
174 messages.add(Message(UUID.randomUUID().toString(), userMsg, true))
175 lastAssistantMsg.clear()
176 messages.add(Message(UUID.randomUUID().toString(), lastAssistantMsg.toString(), false))
177
178 generationJob = lifecycleScope.launch(Dispatchers.Default) {
179 engine.sendUserPrompt(userMsg)
180 .onCompletion {
181 withContext(Dispatchers.Main) {
182 userInputEt.isEnabled = true
183 userActionFab.isEnabled = true
184 }
185 }.collect { token ->
186 withContext(Dispatchers.Main) {
187 val messageCount = messages.size
188 check(messageCount > 0 && !messages[messageCount - 1].isUser)
189
190 messages.removeAt(messageCount - 1).copy(
191 content = lastAssistantMsg.append(token).toString()
192 ).let { messages.add(it) }
193
194 messageAdapter.notifyItemChanged(messages.size - 1)
195 }
196 }
197 }
198 }
199 }
200 }
201
202 /**
203 * Run a benchmark with the model file
204 */
205 @Deprecated("This benchmark doesn't accurately indicate GUI performance expected by app developers")
206 private suspend fun runBenchmark(modelName: String, modelFile: File) =
207 withContext(Dispatchers.Default) {
208 Log.i(TAG, "Starts benchmarking $modelName")
209 withContext(Dispatchers.Main) {
210 userInputEt.hint = "Running benchmark..."
211 }
212 engine.bench(
213 pp=BENCH_PROMPT_PROCESSING_TOKENS,
214 tg=BENCH_TOKEN_GENERATION_TOKENS,
215 pl=BENCH_SEQUENCE,
216 nr=BENCH_REPETITION
217 ).let { result ->
218 messages.add(Message(UUID.randomUUID().toString(), result, false))
219 withContext(Dispatchers.Main) {
220 messageAdapter.notifyItemChanged(messages.size - 1)
221 }
222 }
223 }
224
225 /**
226 * Create the `models` directory if not exist.
227 */
228 private fun ensureModelsDirectory() =
229 File(filesDir, DIRECTORY_MODELS).also {
230 if (it.exists() && !it.isDirectory) { it.delete() }
231 if (!it.exists()) { it.mkdir() }
232 }
233
234 override fun onStop() {
235 generationJob?.cancel()
236 super.onStop()
237 }
238
239 override fun onDestroy() {
240 engine.destroy()
241 super.onDestroy()
242 }
243
244 companion object {
245 private val TAG = MainActivity::class.java.simpleName
246
247 private const val DIRECTORY_MODELS = "models"
248 private const val FILE_EXTENSION_GGUF = ".gguf"
249
250 private const val BENCH_PROMPT_PROCESSING_TOKENS = 512
251 private const val BENCH_TOKEN_GENERATION_TOKENS = 128
252 private const val BENCH_SEQUENCE = 1
253 private const val BENCH_REPETITION = 3
254 }
255}
256
257fun GgufMetadata.filename() = when {
258 basic.name != null -> {
259 basic.name?.let { name ->
260 basic.sizeLabel?.let { size ->
261 "$name-$size"
262 } ?: name
263 }
264 }
265 architecture?.architecture != null -> {
266 architecture?.architecture?.let { arch ->
267 basic.uuid?.let { uuid ->
268 "$arch-$uuid"
269 } ?: "$arch-${System.currentTimeMillis()}"
270 }
271 }
272 else -> {
273 "model-${System.currentTimeMillis().toHexString()}"
274 }
275}
diff --git a/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
new file mode 100644
index 0000000..0439f96
--- /dev/null
+++ b/llama.cpp/examples/llama.android/app/src/main/java/com/example/llama/MessageAdapter.kt
@@ -0,0 +1,51 @@
1package com.example.llama
2
3import android.view.LayoutInflater
4import android.view.View
5import android.view.ViewGroup
6import android.widget.TextView
7import androidx.recyclerview.widget.RecyclerView
8
9data class Message(
10 val id: String,
11 val content: String,
12 val isUser: Boolean
13)
14
15class MessageAdapter(
16 private val messages: List<Message>
17) : RecyclerView.Adapter<RecyclerView.ViewHolder>() {
18
19 companion object {
20 private const val VIEW_TYPE_USER = 1
21 private const val VIEW_TYPE_ASSISTANT = 2
22 }
23
24 override fun getItemViewType(position: Int): Int {
25 return if (messages[position].isUser) VIEW_TYPE_USER else VIEW_TYPE_ASSISTANT
26 }
27
28 override fun onCreateViewHolder(parent: ViewGroup, viewType: Int): RecyclerView.ViewHolder {
29 val layoutInflater = LayoutInflater.from(parent.context)
30 return if (viewType == VIEW_TYPE_USER) {
31 val view = layoutInflater.inflate(R.layout.item_message_user, parent, false)
32 UserMessageViewHolder(view)
33 } else {
34 val view = layoutInflater.inflate(R.layout.item_message_assistant, parent, false)
35 AssistantMessageViewHolder(view)
36 }
37 }
38
39 override fun onBindViewHolder(holder: RecyclerView.ViewHolder, position: Int) {
40 val message = messages[position]
41 if (holder is UserMessageViewHolder || holder is AssistantMessageViewHolder) {
42 val textView = holder.itemView.findViewById<TextView>(R.id.msg_content)
43 textView.text = message.content
44 }
45 }
46
47 override fun getItemCount(): Int = messages.size
48
49 class UserMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
50 class AssistantMessageViewHolder(view: View) : RecyclerView.ViewHolder(view)
51}