1#pragma once
  2
#include "server-task.h"

#include <condition_variable>
#include <cstdint>
#include <deque>
#include <functional>
#include <mutex>
#include <unordered_set>
#include <utility>
#include <vector>
 10
 11// struct for managing server tasks
 12// in most cases, use server_response_reader to post new tasks and retrieve results
 13struct server_queue {
 14private:
 15    int id = 0;
 16    bool running  = false;
 17    bool sleeping = false;
 18    bool req_stop_sleeping = false;
 19    int64_t time_last_task = 0;
 20
 21    // queues
 22    std::deque<server_task> queue_tasks;
 23    std::deque<server_task> queue_tasks_deferred;
 24
 25    std::mutex mutex_tasks;
 26    std::condition_variable condition_tasks;
 27
 28    // callback functions
 29    std::function<void(server_task &&)> callback_new_task;
 30    std::function<void(void)>           callback_update_slots;
 31    std::function<void(bool)>           callback_sleeping_state;
 32
 33public:
 34    // Add a new task to the end of the queue
 35    int post(server_task && task, bool front = false);
 36
 37    // multi-task version of post()
 38    int post(std::vector<server_task> && tasks, bool front = false);
 39
 40    // Add a new task, but defer until one slot is available
 41    void defer(server_task && task);
 42
 43    // Get the next id for creating a new task
 44    int get_new_id();
 45
 46    // Call when the state of one slot is changed, it will move one task from deferred to main queue
 47    // prioritize tasks that use the specified slot (otherwise, pop the first deferred task)
 48    void pop_deferred_task(int id_slot);
 49
 50    // if sleeping, request exiting sleep state and wait until it is done
 51    // returns immediately if not sleeping
 52    void wait_until_no_sleep();
 53
 54    bool is_sleeping() {
 55        std::unique_lock<std::mutex> lock(mutex_tasks);
 56        return sleeping;
 57    }
 58
 59    // end the start_loop routine
 60    void terminate();
 61
 62    /**
 63     * Main loop consists of these steps:
 64     * - Wait until a new task arrives
 65     * - Process the task (i.e. maybe copy data into slot)
 66     * - Check if multitask is finished
 67     * - Update all slots
 68     *
 69     * Sleeping procedure (disabled if idle_sleep_ms < 0):
 70     * - If there is no task after idle_sleep_ms, enter sleeping state
 71     * - Call callback_sleeping_state(true)
 72     * - Wait until req_stop_sleeping is set to true
 73     * - Call callback_sleeping_state(false)
 74     * - Exit sleeping state
 75     */
 76    void start_loop(int64_t idle_sleep_ms = -1);
 77
 78    // for metrics
 79    size_t queue_tasks_deferred_size() {
 80        std::unique_lock<std::mutex> lock(mutex_tasks);
 81        return queue_tasks_deferred.size();
 82    }
 83
 84    //
 85    // Functions below are not thread-safe, must only be used before start_loop() is called
 86    //
 87
 88    // Register function to process a new task
 89    void on_new_task(std::function<void(server_task &&)> callback) {
 90        callback_new_task = std::move(callback);
 91    }
 92
 93    // Register the function to be called when all slots data is ready to be processed
 94    void on_update_slots(std::function<void(void)> callback) {
 95        callback_update_slots = std::move(callback);
 96    }
 97
 98    // Register callback for sleeping state change
 99    // note: when entering sleeping state, the callback is called AFTER sleeping is set to true
100    //       when leaving sleeping state, the callback is called BEFORE sleeping is set to false
101    void on_sleeping_state(std::function<void(bool)> callback) {
102        callback_sleeping_state = std::move(callback);
103    }
104
105private:
106    void cleanup_pending_task(int id_target);
107};
108
109// struct for managing server responses
110// in most cases, use server_response_reader to retrieve results
111struct server_response {
112private:
113    bool running = true;
114
115    // for keeping track of all tasks waiting for the result
116    std::unordered_set<int> waiting_task_ids;
117
118    // the main result queue (using ptr for polymorphism)
119    std::vector<server_task_result_ptr> queue_results;
120
121    std::mutex mutex_results;
122    std::condition_variable condition_results;
123
124public:
125    // add the id_task to the list of tasks waiting for response
126    void add_waiting_task_id(int id_task);
127
128    void add_waiting_task_ids(const std::unordered_set<int> & id_tasks);
129
130    // when the request is finished, we can remove task associated with it
131    void remove_waiting_task_id(int id_task);
132
133    // remove multiple tasks from waiting list
134    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks);
135
136    // This function blocks the thread until there is a response for one of the id_tasks
137    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks);
138
139    // same as recv(), but have timeout in seconds
140    // if timeout is reached, nullptr is returned
141    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout);
142
143    // single-task version of recv()
144    server_task_result_ptr recv(int id_task);
145
146    // Send a new result to a waiting id_task
147    void send(server_task_result_ptr && result);
148
149    // terminate the waiting loop
150    void terminate();
151};
152
153// utility class to make working with server_queue and server_response easier
154// it provides a generator-like API for server responses
155// support pooling connection state and aggregating multiple results
156struct server_response_reader {
157    std::unordered_set<int> id_tasks;
158    server_queue & queue_tasks;
159    server_response & queue_results;
160    size_t received_count = 0;
161    bool cancelled = false;
162    int polling_interval_seconds;
163
164    // tracking generation state and partial tool calls
165    // only used by streaming completions
166    std::vector<task_result_state> states;
167
168    // should_stop function will be called each polling_interval_seconds
169    server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds)
170        : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {}
171    ~server_response_reader() {
172        stop();
173    }
174
175    int get_new_id() {
176        return queue_tasks.get_new_id();
177    }
178
179    // if front = true, the task will be posted to the front of the queue (high priority)
180    void post_task(server_task && task, bool front = false);
181    void post_tasks(std::vector<server_task> && tasks, bool front = false);
182    bool has_next() const;
183
184    // return nullptr if should_stop() is true before receiving a result
185    // note: if one error is received, it will stop further processing and return error result
186    server_task_result_ptr next(const std::function<bool()> & should_stop);
187
188    struct batch_response {
189        bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
190        std::vector<server_task_result_ptr> results;
191        server_task_result_ptr error; // nullptr if no error
192    };
193    // aggregate multiple results
194    batch_response wait_for_all(const std::function<bool()> & should_stop);
195
196    void stop();
197};