summaryrefslogtreecommitdiff
path: root/llama.cpp/tools/server/server-queue.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/tools/server/server-queue.cpp')
-rw-r--r--llama.cpp/tools/server/server-queue.cpp450
1 files changed, 450 insertions, 0 deletions
diff --git a/llama.cpp/tools/server/server-queue.cpp b/llama.cpp/tools/server/server-queue.cpp
new file mode 100644
index 0000000..a2a026a
--- /dev/null
+++ b/llama.cpp/tools/server/server-queue.cpp
@@ -0,0 +1,450 @@
#include "server-task.h"
#include "server-queue.h"

#include "log.h"

#include <algorithm>
#include <chrono>
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+//
+// server_queue
+//
+
+int server_queue::post(server_task && task, bool front) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ GGML_ASSERT(task.id != -1);
+ // if this is cancel task make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ const int task_id = task.id;
+ QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ time_last_task = ggml_time_ms();
+ condition_tasks.notify_one();
+ return task_id;
+}
+
+int server_queue::post(std::vector<server_task> && tasks, bool front) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ for (auto & task : tasks) {
+ if (task.id == -1) {
+ task.id = id++;
+ }
+ // if this is cancel task make sure to clean up pending tasks
+ if (task.type == SERVER_TASK_TYPE_CANCEL) {
+ cleanup_pending_task(task.id_target);
+ }
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ }
+ time_last_task = ggml_time_ms();
+ condition_tasks.notify_one();
+ return 0;
+}
+
+void server_queue::defer(server_task && task) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ QUE_DBG("defer task, id = %d\n", task.id);
+ queue_tasks_deferred.push_back(std::move(task));
+ time_last_task = ggml_time_ms();
+ condition_tasks.notify_one();
+}
+
+int server_queue::get_new_id() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ int new_id = id++;
+ return new_id;
+}
+
+void server_queue::pop_deferred_task(int id_slot) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (!queue_tasks_deferred.empty()) {
+ // try to find a task that uses the specified slot
+ bool found = false;
+ for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
+ if (it->id_slot == id_slot) {
+ QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
+ queue_tasks.emplace_front(std::move(*it));
+ queue_tasks_deferred.erase(it);
+ found = true;
+ break;
+ }
+ }
+ // if not tasks found using the slot, just pop the first deferred task (default behavior)
+ if (!found) {
+ QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
+ queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
+ queue_tasks_deferred.pop_front();
+ }
+ }
+ time_last_task = ggml_time_ms();
+ condition_tasks.notify_one();
+}
+
+void server_queue::wait_until_no_sleep() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ if (!sleeping) {
+ return;
+ } else {
+ if (!req_stop_sleeping) {
+ QUE_DBG("%s", "requesting to stop sleeping\n");
+ req_stop_sleeping = true;
+ condition_tasks.notify_one(); // only main thread is waiting on this
+ }
+ QUE_DBG("%s", "waiting until no sleep\n");
+ condition_tasks.wait(lock, [&]{
+ return !sleeping;
+ });
+ }
+}
+
+void server_queue::terminate() {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ running = false;
+ condition_tasks.notify_all();
+}
+
// Main worker loop; runs until terminate() flips `running`.
// Alternates between three phases:
//   1) drain queue_tasks through callback_new_task
//   2) run one inference round via callback_update_slots
//   3) wait for new tasks, entering a sleeping state after idle_sleep_ms
//      of inactivity (sleeping is disabled when idle_sleep_ms < 0)
void server_queue::start_loop(int64_t idle_sleep_ms) {
    running = true;
    time_last_task = ggml_time_ms();

    // upper bound on a single wait so the idle/sleep condition is re-checked
    // at least once per second even if no task ever arrives
    constexpr auto max_wait_time = std::chrono::seconds(1);
    auto should_sleep = [&]() -> bool {
        // caller must hold mutex_tasks
        if (idle_sleep_ms < 0) {
            return false;
        }
        int64_t now = ggml_time_ms();
        return (now - time_last_task) >= idle_sleep_ms;
    };

    while (true) {
        QUE_DBG("%s", "processing new tasks\n");

        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                lock.unlock();
                break;
            }
            server_task task = std::move(queue_tasks.front());
            queue_tasks.pop_front();
            // release the lock before invoking the callback so new tasks can
            // be posted concurrently while this one is being processed
            lock.unlock();

            QUE_DBG("processing task, id = %d\n", task.id);
            callback_new_task(std::move(task));
        }
        // all tasks in the current loop is processed, slots data is now ready
        QUE_DBG("%s", "update slots\n");

        // this will run the main inference process for all slots
        callback_update_slots();
        {
            // update_slots() may take a while to finish, we need to make sure it's not counted as idle
            std::unique_lock<std::mutex> lock(mutex_tasks);
            time_last_task = ggml_time_ms();
        }

        QUE_DBG("%s", "waiting for new tasks\n");
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running || !queue_tasks.empty()) {
                break; // go back to process new tasks or terminate
            }

            // no tasks, check for sleeping state
            if (should_sleep()) {
                QUE_INF("%s", "entering sleeping state\n");
                sleeping = true;
                // NOTE(review): callback_sleeping_state runs while mutex_tasks is
                // held - it must not post tasks back into this queue (deadlock)
                callback_sleeping_state(true);
                req_stop_sleeping = false;
                // wait until we are requested to exit sleeping state
                condition_tasks.wait(lock, [&]{
                    return (!running || req_stop_sleeping);
                });
                if (!running) { // may changed during sleep
                    break; // terminate
                }
                QUE_INF("%s", "exiting sleeping state\n");
                req_stop_sleeping = false;
                callback_sleeping_state(false);
                sleeping = false;
                time_last_task = ggml_time_ms();
                condition_tasks.notify_all(); // notify wait_until_no_sleep()
                break; // process new tasks
            } else {
                // wait for new tasks or timeout for checking sleeping condition
                bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
                    return (!queue_tasks.empty() || !running);
                });
                if (res) {
                    break; // new task arrived or terminate
                }
                // otherwise, loop again to check sleeping condition
            }
        }
    }
}
+
+void server_queue::cleanup_pending_task(int id_target) {
+ // no need lock because this is called exclusively by post()
+ auto rm_func = [id_target](const server_task & task) {
+ return task.id == id_target;
+ };
+ queue_tasks.erase(
+ std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
+ queue_tasks.end());
+ queue_tasks_deferred.erase(
+ std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
+ queue_tasks_deferred.end());
+}
+
+//
+// server_response
+//
+
+void server_response::add_waiting_task_id(int id_task) {
+ RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.insert(id_task);
+}
+
+void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & id_task : id_tasks) {
+ RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.insert(id_task);
+ }
+}
+
+void server_response::remove_waiting_task_id(int id_task) {
+ RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ waiting_task_ids.erase(id_task);
+ // make sure to clean up all pending results
+ queue_results.erase(
+ std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
+ return res->id == id_task;
+ }),
+ queue_results.end());
+}
+
+void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & id_task : id_tasks) {
+ RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.erase(id_task);
+ }
+}
+
// Block until a result for one of id_tasks is available, then return it.
// If the response queue is terminated while waiting, the whole process is
// aborted via std::terminate(), because the HTTP caller cannot be unwound.
server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
    while (true) {
        std::unique_lock<std::mutex> lock(mutex_results);
        condition_results.wait(lock, [&]{
            if (!running) {
                RES_DBG("%s : queue result stop\n", "recv");
                std::terminate(); // we cannot return here since the caller is HTTP code
            }
            return !queue_results.empty();
        });

        // take the first queued result that belongs to one of our tasks;
        // results for other readers are left in place (hence the outer loop:
        // a wake-up may have been for somebody else's result)
        for (size_t i = 0; i < queue_results.size(); i++) {
            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
                server_task_result_ptr res = std::move(queue_results[i]);
                queue_results.erase(queue_results.begin() + i);
                return res;
            }
        }
    }

    // should never reach here
}
+
// Like recv(), but gives up and returns nullptr after `timeout` seconds
// pass with no matching result. The queue is scanned before waiting, so a
// result that arrived earlier is returned immediately.
server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
    while (true) {
        std::unique_lock<std::mutex> lock(mutex_results);

        for (int i = 0; i < (int) queue_results.size(); i++) {
            if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
                server_task_result_ptr res = std::move(queue_results[i]);
                queue_results.erase(queue_results.begin() + i);
                return res;
            }
        }

        // predicate-less wait: a spurious wake-up simply re-runs the scan above
        std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
        if (!running) {
            RES_DBG("%s : queue result stop\n", __func__);
            std::terminate(); // we cannot return here since the caller is HTTP code
        }
        if (cr_res == std::cv_status::timeout) {
            return nullptr;
        }
    }

    // should never reach here
}
+
+server_task_result_ptr server_response::recv(int id_task) {
+ std::unordered_set<int> id_tasks = {id_task};
+ return recv(id_tasks);
+}
+
+void server_response::send(server_task_result_ptr && result) {
+ RES_DBG("sending result for task id = %d\n", result->id);
+
+ std::unique_lock<std::mutex> lock(mutex_results);
+ for (const auto & id_task : waiting_task_ids) {
+ if (result->id == id_task) {
+ RES_DBG("task id = %d pushed to result queue\n", result->id);
+
+ queue_results.emplace_back(std::move(result));
+ condition_results.notify_all();
+ return;
+ }
+ }
+}
+
+void server_response::terminate() {
+ running = false;
+ condition_results.notify_all();
+}
+
+//
+// server_response_reader
+//
+
// Post a single task and register this reader as the receiver of its results.
// May be called only once per reader; parent tasks must use post_tasks().
void server_response_reader::post_task(server_task && task, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
    GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
    task.index = 0;
    id_tasks.insert(task.id);
    states.push_back(task.create_state());
    // register with the result queue BEFORE posting, so an early result
    // cannot be dropped by server_response::send()
    queue_results.add_waiting_task_id(task.id);
    queue_tasks.post(std::move(task), front);
}
+
// Post a batch of tasks (including their child tasks) and register this
// reader for all of their ids. Each task/child gets a sequential `index`
// so results can later be placed back in submission order.
void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
    GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
    id_tasks = server_task::get_list_id(tasks);
    // reserve covers the parents only; children (if any) may still reallocate
    states.reserve(tasks.size());
    size_t index = 0;
    for (auto & task : tasks) {
        task.index = index++;
        states.push_back(task.create_state());
        // for child tasks
        for (auto & child_task : task.child_tasks) {
            child_task.index = index++;
            states.push_back(child_task.create_state());
        }
    }
    // every id (parents + children) must have exactly one state entry
    GGML_ASSERT(states.size() == id_tasks.size());
    // register with the result queue BEFORE posting, so an early result
    // cannot be dropped by server_response::send()
    queue_results.add_waiting_task_ids(id_tasks);
    queue_tasks.post(std::move(tasks), front);
}
+
+bool server_response_reader::has_next() const {
+ return !cancelled && received_count < id_tasks.size();
+}
+
// Wait for the next result belonging to this reader's tasks.
// Returns nullptr only when should_stop() is true during a polling timeout.
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
    while (true) {
        server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
        if (result == nullptr) {
            // timeout, check stop condition
            if (should_stop()) {
                SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
                return nullptr;
            }
        } else {
            if (result->is_error()) {
                stop(); // cancel remaining tasks
                SRV_DBG("%s", "received error result, stopping further processing\n");
                return result;
            }
            if (!states.empty()) {
                // update the generation state if needed
                const size_t idx = result->index;
                GGML_ASSERT(idx < states.size());
                result->update(states[idx]);
            }
            // only a final ("stop") result counts toward completion;
            // intermediate/streamed results do not advance received_count
            if (result->is_stop()) {
                received_count++;
            }
            return result;
        }
    }

    // should not reach here
}
+
+server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
+ batch_response batch_res;
+ batch_res.results.clear();
+ batch_res.results.resize(id_tasks.size());
+ while (has_next()) {
+ auto res = next(should_stop);
+ if (res == nullptr) {
+ batch_res.is_terminated = true;
+ return batch_res;
+ }
+ if (res->is_error()) {
+ batch_res.error = std::move(res);
+ return batch_res;
+ }
+ const size_t idx = res->index;
+ GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
+ GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
+ batch_res.results[idx] = std::move(res);
+ }
+ return batch_res;
+}
+
// Stop receiving: deregister all task ids and, if some tasks have not
// finished yet, post high-priority cancel tasks for them.
void server_response_reader::stop() {
    // bulk removal only deregisters the ids; it does NOT purge queued results
    queue_results.remove_waiting_task_ids(id_tasks);
    if (has_next() && !cancelled) {
        // if tasks is not finished yet, cancel them
        cancelled = true;
        std::vector<server_task> cancel_tasks;
        cancel_tasks.reserve(id_tasks.size());
        for (const auto & id_task : id_tasks) {
            SRV_WRN("cancel task, id_task = %d\n", id_task);
            server_task task(SERVER_TASK_TYPE_CANCEL);
            task.id_target = id_task;
            // per-id removal additionally drops any result still queued for
            // this task (unlike the bulk call above), so late results cannot
            // leak after cancellation
            queue_results.remove_waiting_task_id(id_task);
            cancel_tasks.push_back(std::move(task));
        }
        // push to beginning of the queue, so it has highest priority
        queue_tasks.post(std::move(cancel_tasks), true);
    } else {
        SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
    }
}