#include "server-task.h"
#include "server-queue.h"

#include "log.h"

#include <algorithm>
#include <chrono>
  7
  8#define QUE_INF(fmt, ...) LOG_INF("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
  9#define QUE_WRN(fmt, ...) LOG_WRN("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 10#define QUE_ERR(fmt, ...) LOG_ERR("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 11#define QUE_DBG(fmt, ...) LOG_DBG("que  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 12
 13#define RES_INF(fmt, ...) LOG_INF("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 14#define RES_WRN(fmt, ...) LOG_WRN("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 15#define RES_ERR(fmt, ...) LOG_ERR("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 16#define RES_DBG(fmt, ...) LOG_DBG("res  %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 17
 18//
 19// server_queue
 20//
 21
 22int server_queue::post(server_task && task, bool front) {
 23    std::unique_lock<std::mutex> lock(mutex_tasks);
 24    GGML_ASSERT(task.id != -1);
 25    // if this is cancel task make sure to clean up pending tasks
 26    if (task.type == SERVER_TASK_TYPE_CANCEL) {
 27        cleanup_pending_task(task.id_target);
 28    }
 29    const int task_id = task.id;
 30    QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
 31    if (front) {
 32        queue_tasks.push_front(std::move(task));
 33    } else {
 34        queue_tasks.push_back(std::move(task));
 35    }
 36    time_last_task = ggml_time_ms();
 37    condition_tasks.notify_one();
 38    return task_id;
 39}
 40
 41int server_queue::post(std::vector<server_task> && tasks, bool front) {
 42    std::unique_lock<std::mutex> lock(mutex_tasks);
 43    for (auto & task : tasks) {
 44        if (task.id == -1) {
 45            task.id = id++;
 46        }
 47        // if this is cancel task make sure to clean up pending tasks
 48        if (task.type == SERVER_TASK_TYPE_CANCEL) {
 49            cleanup_pending_task(task.id_target);
 50        }
 51        QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
 52        if (front) {
 53            queue_tasks.push_front(std::move(task));
 54        } else {
 55            queue_tasks.push_back(std::move(task));
 56        }
 57    }
 58    time_last_task = ggml_time_ms();
 59    condition_tasks.notify_one();
 60    return 0;
 61}
 62
 63void server_queue::defer(server_task && task) {
 64    std::unique_lock<std::mutex> lock(mutex_tasks);
 65    QUE_DBG("defer task, id = %d\n", task.id);
 66    queue_tasks_deferred.push_back(std::move(task));
 67    time_last_task = ggml_time_ms();
 68    condition_tasks.notify_one();
 69}
 70
 71int server_queue::get_new_id() {
 72    std::unique_lock<std::mutex> lock(mutex_tasks);
 73    int new_id = id++;
 74    return new_id;
 75}
 76
 77void server_queue::pop_deferred_task(int id_slot) {
 78    std::unique_lock<std::mutex> lock(mutex_tasks);
 79    if (!queue_tasks_deferred.empty()) {
 80        // try to find a task that uses the specified slot
 81        bool found = false;
 82        for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
 83            if (it->id_slot == id_slot) {
 84                QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
 85                queue_tasks.emplace_front(std::move(*it));
 86                queue_tasks_deferred.erase(it);
 87                found = true;
 88                break;
 89            }
 90        }
 91        // if not tasks found using the slot, just pop the first deferred task (default behavior)
 92        if (!found) {
 93            QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
 94            queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
 95            queue_tasks_deferred.pop_front();
 96        }
 97    }
 98    time_last_task = ggml_time_ms();
 99    condition_tasks.notify_one();
100}
101
102void server_queue::wait_until_no_sleep() {
103    std::unique_lock<std::mutex> lock(mutex_tasks);
104    if (!sleeping) {
105        return;
106    } else {
107        if (!req_stop_sleeping) {
108            QUE_DBG("%s", "requesting to stop sleeping\n");
109            req_stop_sleeping = true;
110            condition_tasks.notify_one(); // only main thread is waiting on this
111        }
112        QUE_DBG("%s", "waiting until no sleep\n");
113        condition_tasks.wait(lock, [&]{
114            return !sleeping;
115        });
116    }
117}
118
119void server_queue::terminate() {
120    std::unique_lock<std::mutex> lock(mutex_tasks);
121    running = false;
122    condition_tasks.notify_all();
123}
124
// Main task-processing loop; runs on the main thread until terminate() clears
// `running`. Each iteration: drain pending tasks, run callback_update_slots(),
// then wait for new tasks. If no task activity happens for idle_sleep_ms
// milliseconds, the loop enters a sleeping state until wait_until_no_sleep()
// (or terminate()) wakes it. A negative idle_sleep_ms disables sleeping.
void server_queue::start_loop(int64_t idle_sleep_ms) {
    running = true;
    time_last_task = ggml_time_ms();

    // upper bound per wait so the sleeping condition is re-checked periodically
    constexpr auto max_wait_time = std::chrono::seconds(1);
    auto should_sleep = [&]() -> bool {
        // caller must hold mutex_tasks
        if (idle_sleep_ms < 0) {
            return false; // sleeping disabled
        }
        int64_t now = ggml_time_ms();
        return (now - time_last_task) >= idle_sleep_ms;
    };

    while (true) {
        QUE_DBG("%s", "processing new tasks\n");

        // drain the queue, dispatching each task with the lock released
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                lock.unlock();
                break;
            }
            server_task task = std::move(queue_tasks.front());
            queue_tasks.pop_front();
            // unlock before the callback: it may post new tasks itself
            lock.unlock();

            QUE_DBG("processing task, id = %d\n", task.id);
            callback_new_task(std::move(task));
        }
        // all tasks in the current loop is processed, slots data is now ready
        QUE_DBG("%s", "update slots\n");

        // this will run the main inference process for all slots
        callback_update_slots();
        {
            // update_slots() may take a while to finish, we need to make sure it's not counted as idle
            std::unique_lock<std::mutex> lock(mutex_tasks);
            time_last_task = ggml_time_ms();
        }

        QUE_DBG("%s", "waiting for new tasks\n");
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running || !queue_tasks.empty()) {
                break; // go back to process new tasks or terminate
            }

            // no tasks, check for sleeping state
            if (should_sleep()) {
                QUE_INF("%s", "entering sleeping state\n");
                sleeping = true;
                callback_sleeping_state(true);
                req_stop_sleeping = false;
                // wait until we are requested to exit sleeping state
                condition_tasks.wait(lock, [&]{
                    return (!running || req_stop_sleeping);
                });
                if (!running) { // may have changed during sleep
                    // NOTE(review): `sleeping` remains true on this path, so a thread
                    // blocked in wait_until_no_sleep() would never be released after
                    // terminate() — confirm this is the intended shutdown behavior
                    break; // terminate
                }
                QUE_INF("%s", "exiting sleeping state\n");
                req_stop_sleeping = false;
                callback_sleeping_state(false);
                sleeping = false;
                time_last_task = ggml_time_ms();
                condition_tasks.notify_all(); // notify wait_until_no_sleep()
                break; // process new tasks
            } else {
                // wait for new tasks or timeout for checking sleeping condition
                bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
                    return (!queue_tasks.empty() || !running);
                });
                if (res) {
                    break; // new task arrived or terminate
                }
                // otherwise, loop again to check sleeping condition
            }
        }
    }
}
210
211void server_queue::cleanup_pending_task(int id_target) {
212    // no need lock because this is called exclusively by post()
213    auto rm_func = [id_target](const server_task & task) {
214        return task.id == id_target;
215    };
216    queue_tasks.erase(
217        std::remove_if(queue_tasks.begin(),          queue_tasks.end(),          rm_func),
218        queue_tasks.end());
219    queue_tasks_deferred.erase(
220        std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
221        queue_tasks_deferred.end());
222}
223
224//
225// server_response
226//
227
228void server_response::add_waiting_task_id(int id_task) {
229    RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
230
231    std::unique_lock<std::mutex> lock(mutex_results);
232    waiting_task_ids.insert(id_task);
233}
234
235void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
236    std::unique_lock<std::mutex> lock(mutex_results);
237
238    for (const auto & id_task : id_tasks) {
239        RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
240        waiting_task_ids.insert(id_task);
241    }
242}
243
244void server_response::remove_waiting_task_id(int id_task) {
245    RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
246
247    std::unique_lock<std::mutex> lock(mutex_results);
248    waiting_task_ids.erase(id_task);
249    // make sure to clean up all pending results
250    queue_results.erase(
251        std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
252            return res->id == id_task;
253        }),
254        queue_results.end());
255}
256
257void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
258    std::unique_lock<std::mutex> lock(mutex_results);
259
260    for (const auto & id_task : id_tasks) {
261        RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
262        waiting_task_ids.erase(id_task);
263    }
264}
265
266server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
267    while (true) {
268        std::unique_lock<std::mutex> lock(mutex_results);
269        condition_results.wait(lock, [&]{
270            if (!running) {
271                RES_DBG("%s : queue result stop\n", "recv");
272                std::terminate(); // we cannot return here since the caller is HTTP code
273            }
274            return !queue_results.empty();
275        });
276
277        for (size_t i = 0; i < queue_results.size(); i++) {
278            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
279                server_task_result_ptr res = std::move(queue_results[i]);
280                queue_results.erase(queue_results.begin() + i);
281                return res;
282            }
283        }
284    }
285
286    // should never reach here
287}
288
289server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
290    while (true) {
291        std::unique_lock<std::mutex> lock(mutex_results);
292
293        for (int i = 0; i < (int) queue_results.size(); i++) {
294            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
295                server_task_result_ptr res = std::move(queue_results[i]);
296                queue_results.erase(queue_results.begin() + i);
297                return res;
298            }
299        }
300
301        std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
302        if (!running) {
303            RES_DBG("%s : queue result stop\n", __func__);
304            std::terminate(); // we cannot return here since the caller is HTTP code
305        }
306        if (cr_res == std::cv_status::timeout) {
307            return nullptr;
308        }
309    }
310
311    // should never reach here
312}
313
314server_task_result_ptr server_response::recv(int id_task) {
315    std::unordered_set<int> id_tasks = {id_task};
316    return recv(id_tasks);
317}
318
319void server_response::send(server_task_result_ptr && result) {
320    RES_DBG("sending result for task id = %d\n", result->id);
321
322    std::unique_lock<std::mutex> lock(mutex_results);
323    for (const auto & id_task : waiting_task_ids) {
324        if (result->id == id_task) {
325            RES_DBG("task id = %d pushed to result queue\n", result->id);
326
327            queue_results.emplace_back(std::move(result));
328            condition_results.notify_all();
329            return;
330        }
331    }
332}
333
334void server_response::terminate() {
335    running = false;
336    condition_results.notify_all();
337}
338
339//
340// server_response_reader
341//
342
343void server_response_reader::post_task(server_task && task, bool front) {
344    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
345    GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
346    task.index = 0;
347    id_tasks.insert(task.id);
348    states.push_back(task.create_state());
349    queue_results.add_waiting_task_id(task.id);
350    queue_tasks.post(std::move(task), front);
351}
352
// Register a batch of tasks (including their child tasks) with this reader,
// assigning each a contiguous result index, then post them to the task queue.
// May be called at most once per reader.
void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
    id_tasks = server_task::get_list_id(tasks);
    // NOTE(review): reserve() counts only top-level tasks; the child-task states
    // below may push past this capacity (correct, just a possible extra allocation)
    states.reserve(tasks.size());
    size_t index = 0;
    for (auto & task : tasks) {
        task.index = index++;
        states.push_back(task.create_state());
        // child tasks get their own indices/states, interleaved after the parent
        for (auto & child_task : task.child_tasks) {
            child_task.index = index++;
            states.push_back(child_task.create_state());
        }
    }
    // one state per tracked task id (children included)
    GGML_ASSERT(states.size() == id_tasks.size());
    queue_results.add_waiting_task_ids(id_tasks);
    queue_tasks.post(std::move(tasks), front);
}
371
372bool server_response_reader::has_next() const {
373    return !cancelled && received_count < id_tasks.size();
374}
375
376// return nullptr if should_stop() is true before receiving a result
377// note: if one error is received, it will stop further processing and return error result
378server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
379    while (true) {
380        server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
381        if (result == nullptr) {
382            // timeout, check stop condition
383            if (should_stop()) {
384                SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
385                return nullptr;
386            }
387        } else {
388            if (result->is_error()) {
389                stop(); // cancel remaining tasks
390                SRV_DBG("%s", "received error result, stopping further processing\n");
391                return result;
392            }
393            if (!states.empty()) {
394                // update the generation state if needed
395                const size_t idx = result->index;
396                GGML_ASSERT(idx < states.size());
397                result->update(states[idx]);
398            }
399            if (result->is_stop()) {
400                received_count++;
401            }
402            return result;
403        }
404    }
405
406    // should not reach here
407}
408
409server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
410    batch_response batch_res;
411    batch_res.results.clear();
412    batch_res.results.resize(id_tasks.size());
413    while (has_next()) {
414        auto res = next(should_stop);
415        if (res == nullptr) {
416            batch_res.is_terminated = true;
417            return batch_res;
418        }
419        if (res->is_error()) {
420            batch_res.error = std::move(res);
421            return batch_res;
422        }
423        const size_t idx = res->index;
424        GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
425        GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
426        batch_res.results[idx] = std::move(res);
427    }
428    return batch_res;
429}
430
// Stop this reader: unregister its task ids and, if any tasks are still
// in flight, post high-priority CANCEL tasks for them.
void server_response_reader::stop() {
    // stop receiving new results for these ids
    queue_results.remove_waiting_task_ids(id_tasks);
    if (has_next() && !cancelled) {
        // if tasks is not finished yet, cancel them
        cancelled = true;
        std::vector<server_task> cancel_tasks;
        cancel_tasks.reserve(id_tasks.size());
        for (const auto & id_task : id_tasks) {
            SRV_WRN("cancel task, id_task = %d\n", id_task);
            server_task task(SERVER_TASK_TYPE_CANCEL);
            task.id_target = id_task;
            // not redundant with remove_waiting_task_ids() above: the per-id
            // variant also purges results already sitting in queue_results
            queue_results.remove_waiting_task_id(id_task);
            cancel_tasks.push_back(std::move(task));
        }
        // push to beginning of the queue, so it has highest priority
        queue_tasks.post(std::move(cancel_tasks), true);
    } else {
        SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
    }
}