1package com.arm.aichat
 2
 3import com.arm.aichat.InferenceEngine.State
 4import kotlinx.coroutines.flow.Flow
 5import kotlinx.coroutines.flow.StateFlow
 6
 7/**
 8 * Interface defining the core LLM inference operations.
 9 */
10interface InferenceEngine {
11    /**
12     * Current state of the inference engine
13     */
14    val state: StateFlow<State>
15
16    /**
17     * Load a model from the given path.
18     *
19     * @throws UnsupportedArchitectureException if model architecture not supported
20     */
21    suspend fun loadModel(pathToModel: String)
22
23    /**
24     * Sends a system prompt to the loaded model
25     */
26    suspend fun setSystemPrompt(systemPrompt: String)
27
28    /**
29     * Sends a user prompt to the loaded model and returns a Flow of generated tokens.
30     */
31    fun sendUserPrompt(message: String, predictLength: Int = DEFAULT_PREDICT_LENGTH): Flow<String>
32
33    /**
34     * Runs a benchmark with the specified parameters.
35     */
36    suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String
37
38    /**
39     * Unloads the currently loaded model.
40     */
41    fun cleanUp()
42
43    /**
44     * Cleans up resources when the engine is no longer needed.
45     */
46    fun destroy()
47
48    /**
49     * States of the inference engine
50     */
51    sealed class State {
52        object Uninitialized : State()
53        object Initializing : State()
54        object Initialized : State()
55
56        object LoadingModel : State()
57        object UnloadingModel : State()
58        object ModelReady : State()
59
60        object Benchmarking : State()
61        object ProcessingSystemPrompt : State()
62        object ProcessingUserPrompt : State()
63
64        object Generating : State()
65
66        data class Error(val exception: Exception) : State()
67    }
68
69    companion object {
70        const val DEFAULT_PREDICT_LENGTH = 1024
71    }
72}
73
/**
 * True while the engine is performing an operation that must not be
 * interrupted: initialization, model (un)loading, benchmarking, or
 * prompt processing.
 *
 * [State.Generating] is deliberately excluded — generation is the one
 * busy state that may be interrupted.
 */
val State.isUninterruptible: Boolean
    get() = when (this) {
        is State.Initializing,
        is State.LoadingModel,
        is State.UnloadingModel,
        is State.Benchmarking,
        is State.ProcessingSystemPrompt,
        is State.ProcessingUserPrompt -> true
        else -> false
    }
81
/**
 * True whenever a model is resident: the engine is either idle with a loaded
 * model ([State.ModelReady]) or actively using it (benchmarking, processing a
 * prompt, or generating).
 */
val State.isModelLoaded: Boolean
    get() = when (this) {
        is State.ModelReady,
        is State.Benchmarking,
        is State.ProcessingSystemPrompt,
        is State.ProcessingUserPrompt,
        is State.Generating -> true
        else -> false
    }
88
/**
 * Thrown by [InferenceEngine.loadModel] when the model's architecture is not
 * supported by the engine.
 *
 * @param message detail message for logs/diagnostics; defaults to a generic
 *     description so existing no-arg call sites keep working while the
 *     exception now carries useful text instead of a null message.
 */
class UnsupportedArchitectureException(
    message: String = "Model architecture is not supported"
) : Exception(message)