blob: 26c1668724662f9fb57f93f725e36d5678a4aced (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
|
package com.arm.aichat
import com.arm.aichat.InferenceEngine.State
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.StateFlow
/**
 * Interface defining the core LLM inference operations.
 *
 * Implementations own the resources backing a single loaded model and expose
 * their lifecycle through [state]. Typical call order: [loadModel], optionally
 * [setSystemPrompt], then [sendUserPrompt]; [cleanUp] unloads the model and
 * [destroy] releases the engine itself.
 */
interface InferenceEngine {
/**
 * Current state of the inference engine.
 *
 * Observe this flow to gate operations (e.g. only send prompts once the
 * value is [State.ModelReady]) or to drive UI.
 */
val state: StateFlow<State>
/**
 * Load a model from the given path.
 *
 * @param pathToModel filesystem path of the model file to load
 * @throws UnsupportedArchitectureException if model architecture not supported
 */
suspend fun loadModel(pathToModel: String)
/**
 * Sends a system prompt to the loaded model.
 *
 * @param systemPrompt system prompt text applied to the current session
 */
suspend fun setSystemPrompt(systemPrompt: String)
/**
 * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
 *
 * @param message user prompt text
 * @param predictLength maximum number of tokens to generate; defaults to
 *   [DEFAULT_PREDICT_LENGTH]
 * @return a [Flow] emitting generated tokens as they are produced
 */
fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
/**
 * Runs a benchmark with the specified parameters.
 *
 * NOTE(review): parameter meanings assumed from llama.cpp bench conventions
 * (pp = prompt-processing tokens, tg = tokens to generate, pl = parallel
 * sequences, nr = number of repetitions) — confirm against the implementation.
 *
 * @return a human-readable benchmark report
 */
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
/**
 * Unloads the currently loaded model.
 */
fun cleanUp()
/**
 * Cleans up resources when the engine is no longer needed.
 */
fun destroy()
/**
 * States of the inference engine.
 *
 * Transitional states (loading, unloading, prompt processing, benchmarking)
 * are distinguished from resting states (ready, initialized) so callers can
 * tell when the engine must not be interrupted — see [isUninterruptible].
 */
sealed class State {
object Uninitialized : State()
object Initializing : State()
object Initialized : State()
object LoadingModel : State()
object UnloadingModel : State()
object ModelReady : State()
object Benchmarking : State()
object ProcessingSystemPrompt : State()
object ProcessingUserPrompt : State()
object Generating : State()
/** Terminal failure state; [exception] carries the cause. */
data class Error(val exception: Exception) : State()
}
companion object {
// Default cap on generated tokens for sendUserPrompt.
const val DEFAULT_PREDICT_LENGTH = 1024
}
}
/**
 * True while the engine is in a transitional state that must not be
 * interrupted: initialization, model load/unload, benchmarking, or prompt
 * processing.
 *
 * Implemented as an exhaustive `when` (no `else`) so that adding a new
 * [State] forces this property to be revisited at compile time, and with an
 * explicit return type for consistency with [isModelLoaded].
 */
val State.isUninterruptible: Boolean
    get() = when (this) {
        State.Initializing,
        State.LoadingModel,
        State.UnloadingModel,
        State.Benchmarking,
        State.ProcessingSystemPrompt,
        State.ProcessingUserPrompt -> true
        State.Uninitialized,
        State.Initialized,
        State.ModelReady,
        State.Generating -> false
        is State.Error -> false
    }
/**
 * True whenever a model is resident in memory — i.e. the engine is ready,
 * benchmarking, processing a prompt, or generating tokens.
 */
val State.isModelLoaded: Boolean
    get() = when (this) {
        is State.ModelReady,
        is State.Benchmarking,
        is State.ProcessingSystemPrompt,
        is State.ProcessingUserPrompt,
        is State.Generating -> true
        else -> false
    }
/**
 * Thrown by [InferenceEngine.loadModel] when the model file's architecture is
 * not supported by the engine.
 *
 * Backward compatible with the previous no-arg form: both parameters default
 * to `null`, so existing `UnsupportedArchitectureException()` call sites are
 * unaffected.
 *
 * @param message optional detail, e.g. the unsupported architecture name
 * @param cause optional underlying exception from the loader
 */
class UnsupportedArchitectureException(
    message: String? = null,
    cause: Throwable? = null,
) : Exception(message, cause)
|