1package com.arm.aichat
2
3import com.arm.aichat.InferenceEngine.State
4import kotlinx.coroutines.flow.Flow
5import kotlinx.coroutines.flow.StateFlow
6
/**
 * Interface defining the core LLM inference operations.
 *
 * The engine's lifecycle and progress are exposed through [state]; see [State] for the
 * possible values.
 */
interface InferenceEngine {
    /**
     * Current state of the inference engine, observable as a [StateFlow].
     */
    val state: StateFlow<State>

    /**
     * Loads a model from the given path.
     *
     * @param pathToModel path of the model file to load.
     * @throws UnsupportedArchitectureException if the model architecture is not supported.
     */
    suspend fun loadModel(pathToModel: String)

    /**
     * Sends a system prompt to the loaded model.
     *
     * @param systemPrompt the system prompt text.
     */
    suspend fun setSystemPrompt(systemPrompt: String)

    /**
     * Sends a user prompt to the loaded model and returns a [Flow] of generated tokens.
     *
     * @param message the user prompt text.
     * @param predictLength generation length limit — presumably a maximum token count;
     *   NOTE(review): confirm against the implementation. Defaults to [DEFAULT_PREDICT_LENGTH].
     * @return a [Flow] emitting generated tokens as strings.
     */
    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>

    /**
     * Runs a benchmark with the specified parameters.
     *
     * NOTE(review): the parameter meanings below are inferred from llama-bench naming
     * conventions and are not established by this file — confirm against the implementation.
     *
     * @param pp presumably the prompt-processing size (tokens).
     * @param tg presumably the text-generation size (tokens).
     * @param pl presumably the number of parallel sequences.
     * @param nr presumably the number of repetitions; defaults to 1.
     * @return the benchmark result as a string (format defined by the implementation).
     */
    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String

    /**
     * Unloads the currently loaded model.
     */
    fun cleanUp()

    /**
     * Cleans up resources when the engine is no longer needed.
     */
    fun destroy()

    /**
     * States of the inference engine.
     */
    sealed class State {
        object Uninitialized : State()
        object Initializing : State()
        object Initialized : State()

        object LoadingModel : State()
        object UnloadingModel : State()
        object ModelReady : State()

        object Benchmarking : State()
        object ProcessingSystemPrompt : State()
        object ProcessingUserPrompt : State()

        object Generating : State()

        /** Failure state carrying the [exception] that caused it. */
        data class Error(val exception: Exception) : State()

        /**
         * Readable name for logging, e.g. `ModelReady`, instead of the default
         * `InferenceEngine$State$ModelReady@<hash>`. Left open (non-final) on purpose so
         * that [Error], as a data class, still generates its own `Error(exception=...)`.
         */
        override fun toString(): String = javaClass.simpleName
    }

    companion object {
        /** Default value for the `predictLength` argument of [sendUserPrompt]. */
        const val DEFAULT_PREDICT_LENGTH = 1024
    }
}
73
/**
 * `true` when this state is one of [State.Initializing], [State.LoadingModel],
 * [State.UnloadingModel], [State.Benchmarking], [State.ProcessingSystemPrompt] or
 * [State.ProcessingUserPrompt]; `false` for every other state (including [State.Generating]).
 *
 * The `when` below is exhaustive with no `else` branch, so adding a new [State] subclass
 * is a compile error here until it is explicitly classified.
 */
val State.isUninterruptible: Boolean
    get() = when (this) {
        is State.Initializing,
        is State.LoadingModel,
        is State.UnloadingModel,
        is State.Benchmarking,
        is State.ProcessingSystemPrompt,
        is State.ProcessingUserPrompt -> true

        is State.Uninitialized,
        is State.Initialized,
        is State.ModelReady,
        is State.Generating,
        is State.Error -> false
    }
81
/**
 * `true` when this state is one of [State.ModelReady], [State.Benchmarking],
 * [State.ProcessingSystemPrompt], [State.ProcessingUserPrompt] or [State.Generating].
 *
 * Every other state yields `false`, including [State.LoadingModel],
 * [State.UnloadingModel] and [State.Error].
 */
val State.isModelLoaded: Boolean
    get() = when (this) {
        is State.ModelReady,
        is State.Benchmarking,
        is State.ProcessingSystemPrompt,
        is State.ProcessingUserPrompt,
        is State.Generating -> true

        is State.Uninitialized,
        is State.Initializing,
        is State.Initialized,
        is State.LoadingModel,
        is State.UnloadingModel,
        is State.Error -> false
    }
88
/**
 * Thrown when a model's architecture is not supported by the inference engine.
 *
 * @param message optional detail, e.g. the name of the unsupported architecture.
 * @param cause optional underlying failure.
 */
class UnsupportedArchitectureException @JvmOverloads constructor(
    message: String? = null,
    cause: Throwable? = null,
) : Exception(message, cause)