1#include "server-task.h"
2#include "server-queue.h"
3
4#include "log.h"
5
6#include <chrono>
7
8#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
9#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
10#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
11#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
12
13#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
14#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
15#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
16#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__)
17
18//
19// server_queue
20//
21
22int server_queue::post(server_task && task, bool front) {
23 std::unique_lock<std::mutex> lock(mutex_tasks);
24 GGML_ASSERT(task.id != -1);
25 // if this is cancel task make sure to clean up pending tasks
26 if (task.type == SERVER_TASK_TYPE_CANCEL) {
27 cleanup_pending_task(task.id_target);
28 }
29 const int task_id = task.id;
30 QUE_DBG("new task, id = %d, front = %d\n", task_id, front);
31 if (front) {
32 queue_tasks.push_front(std::move(task));
33 } else {
34 queue_tasks.push_back(std::move(task));
35 }
36 time_last_task = ggml_time_ms();
37 condition_tasks.notify_one();
38 return task_id;
39}
40
41int server_queue::post(std::vector<server_task> && tasks, bool front) {
42 std::unique_lock<std::mutex> lock(mutex_tasks);
43 for (auto & task : tasks) {
44 if (task.id == -1) {
45 task.id = id++;
46 }
47 // if this is cancel task make sure to clean up pending tasks
48 if (task.type == SERVER_TASK_TYPE_CANCEL) {
49 cleanup_pending_task(task.id_target);
50 }
51 QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
52 if (front) {
53 queue_tasks.push_front(std::move(task));
54 } else {
55 queue_tasks.push_back(std::move(task));
56 }
57 }
58 time_last_task = ggml_time_ms();
59 condition_tasks.notify_one();
60 return 0;
61}
62
63void server_queue::defer(server_task && task) {
64 std::unique_lock<std::mutex> lock(mutex_tasks);
65 QUE_DBG("defer task, id = %d\n", task.id);
66 queue_tasks_deferred.push_back(std::move(task));
67 time_last_task = ggml_time_ms();
68 condition_tasks.notify_one();
69}
70
71int server_queue::get_new_id() {
72 std::unique_lock<std::mutex> lock(mutex_tasks);
73 int new_id = id++;
74 return new_id;
75}
76
77void server_queue::pop_deferred_task(int id_slot) {
78 std::unique_lock<std::mutex> lock(mutex_tasks);
79 if (!queue_tasks_deferred.empty()) {
80 // try to find a task that uses the specified slot
81 bool found = false;
82 for (auto it = queue_tasks_deferred.begin(); it != queue_tasks_deferred.end(); ++it) {
83 if (it->id_slot == id_slot) {
84 QUE_DBG("pop deferred task (use slot %d), id_task = %d\n", id_slot, it->id);
85 queue_tasks.emplace_front(std::move(*it));
86 queue_tasks_deferred.erase(it);
87 found = true;
88 break;
89 }
90 }
91 // if not tasks found using the slot, just pop the first deferred task (default behavior)
92 if (!found) {
93 QUE_DBG("pop deferred task, id_task = %d\n", queue_tasks_deferred.front().id);
94 queue_tasks.emplace_front(std::move(queue_tasks_deferred.front()));
95 queue_tasks_deferred.pop_front();
96 }
97 }
98 time_last_task = ggml_time_ms();
99 condition_tasks.notify_one();
100}
101
102void server_queue::wait_until_no_sleep() {
103 std::unique_lock<std::mutex> lock(mutex_tasks);
104 if (!sleeping) {
105 return;
106 } else {
107 if (!req_stop_sleeping) {
108 QUE_DBG("%s", "requesting to stop sleeping\n");
109 req_stop_sleeping = true;
110 condition_tasks.notify_one(); // only main thread is waiting on this
111 }
112 QUE_DBG("%s", "waiting until no sleep\n");
113 condition_tasks.wait(lock, [&]{
114 return !sleeping;
115 });
116 }
117}
118
119void server_queue::terminate() {
120 std::unique_lock<std::mutex> lock(mutex_tasks);
121 running = false;
122 condition_tasks.notify_all();
123}
124
// main processing loop: drain queued tasks, run slot updates, then wait for
// new work - optionally entering a "sleeping" state after idle_sleep_ms of
// inactivity (idle_sleep_ms < 0 disables sleeping).
// runs until terminate() flips `running`.
void server_queue::start_loop(int64_t idle_sleep_ms) {
    running = true;
    time_last_task = ggml_time_ms();

    // upper bound on each wait so the sleep condition is re-checked regularly
    constexpr auto max_wait_time = std::chrono::seconds(1);
    auto should_sleep = [&]() -> bool {
        // caller must hold mutex_tasks
        if (idle_sleep_ms < 0) {
            return false;
        }
        int64_t now = ggml_time_ms();
        return (now - time_last_task) >= idle_sleep_ms;
    };

    while (true) {
        QUE_DBG("%s", "processing new tasks\n");

        // phase 1: pop and dispatch every queued task; the lock is released
        // before invoking the callback so tasks can post follow-up work
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running) {
                QUE_DBG("%s", "terminate\n");
                return;
            }
            if (queue_tasks.empty()) {
                lock.unlock();
                break;
            }
            server_task task = std::move(queue_tasks.front());
            queue_tasks.pop_front();
            lock.unlock();

            QUE_DBG("processing task, id = %d\n", task.id);
            callback_new_task(std::move(task));
        }
        // all tasks in the current loop is processed, slots data is now ready
        QUE_DBG("%s", "update slots\n");

        // this will run the main inference process for all slots
        callback_update_slots();
        {
            // update_slots() may take a while to finish, we need to make sure it's not counted as idle
            std::unique_lock<std::mutex> lock(mutex_tasks);
            time_last_task = ggml_time_ms();
        }

        // phase 2: wait for new tasks, periodically re-checking whether we
        // should enter the sleeping state
        QUE_DBG("%s", "waiting for new tasks\n");
        while (true) {
            std::unique_lock<std::mutex> lock(mutex_tasks);
            if (!running || !queue_tasks.empty()) {
                break; // go back to process new tasks or terminate
            }

            // no tasks, check for sleeping state
            if (should_sleep()) {
                QUE_INF("%s", "entering sleeping state\n");
                sleeping = true;
                callback_sleeping_state(true);
                req_stop_sleeping = false;
                // wait until we are requested to exit sleeping state
                condition_tasks.wait(lock, [&]{
                    return (!running || req_stop_sleeping);
                });
                if (!running) { // may changed during sleep
                    break; // terminate
                }
                QUE_INF("%s", "exiting sleeping state\n");
                req_stop_sleeping = false;
                callback_sleeping_state(false);
                sleeping = false;
                time_last_task = ggml_time_ms();
                condition_tasks.notify_all(); // notify wait_until_no_sleep()
                break; // process new tasks
            } else {
                // wait for new tasks or timeout for checking sleeping condition
                bool res = condition_tasks.wait_for(lock, max_wait_time, [&]{
                    return (!queue_tasks.empty() || !running);
                });
                if (res) {
                    break; // new task arrived or terminate
                }
                // otherwise, loop again to check sleeping condition
            }
        }
    }
}
210
211void server_queue::cleanup_pending_task(int id_target) {
212 // no need lock because this is called exclusively by post()
213 auto rm_func = [id_target](const server_task & task) {
214 return task.id == id_target;
215 };
216 queue_tasks.erase(
217 std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func),
218 queue_tasks.end());
219 queue_tasks_deferred.erase(
220 std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func),
221 queue_tasks_deferred.end());
222}
223
224//
225// server_response
226//
227
228void server_response::add_waiting_task_id(int id_task) {
229 RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
230
231 std::unique_lock<std::mutex> lock(mutex_results);
232 waiting_task_ids.insert(id_task);
233}
234
235void server_response::add_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
236 std::unique_lock<std::mutex> lock(mutex_results);
237
238 for (const auto & id_task : id_tasks) {
239 RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
240 waiting_task_ids.insert(id_task);
241 }
242}
243
244void server_response::remove_waiting_task_id(int id_task) {
245 RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
246
247 std::unique_lock<std::mutex> lock(mutex_results);
248 waiting_task_ids.erase(id_task);
249 // make sure to clean up all pending results
250 queue_results.erase(
251 std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) {
252 return res->id == id_task;
253 }),
254 queue_results.end());
255}
256
257void server_response::remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
258 std::unique_lock<std::mutex> lock(mutex_results);
259
260 for (const auto & id_task : id_tasks) {
261 RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
262 waiting_task_ids.erase(id_task);
263 }
264}
265
266server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_tasks) {
267 while (true) {
268 std::unique_lock<std::mutex> lock(mutex_results);
269 condition_results.wait(lock, [&]{
270 if (!running) {
271 RES_DBG("%s : queue result stop\n", "recv");
272 std::terminate(); // we cannot return here since the caller is HTTP code
273 }
274 return !queue_results.empty();
275 });
276
277 for (size_t i = 0; i < queue_results.size(); i++) {
278 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
279 server_task_result_ptr res = std::move(queue_results[i]);
280 queue_results.erase(queue_results.begin() + i);
281 return res;
282 }
283 }
284 }
285
286 // should never reach here
287}
288
289server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
290 while (true) {
291 std::unique_lock<std::mutex> lock(mutex_results);
292
293 for (int i = 0; i < (int) queue_results.size(); i++) {
294 if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
295 server_task_result_ptr res = std::move(queue_results[i]);
296 queue_results.erase(queue_results.begin() + i);
297 return res;
298 }
299 }
300
301 std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout));
302 if (!running) {
303 RES_DBG("%s : queue result stop\n", __func__);
304 std::terminate(); // we cannot return here since the caller is HTTP code
305 }
306 if (cr_res == std::cv_status::timeout) {
307 return nullptr;
308 }
309 }
310
311 // should never reach here
312}
313
314server_task_result_ptr server_response::recv(int id_task) {
315 std::unordered_set<int> id_tasks = {id_task};
316 return recv(id_tasks);
317}
318
319void server_response::send(server_task_result_ptr && result) {
320 RES_DBG("sending result for task id = %d\n", result->id);
321
322 std::unique_lock<std::mutex> lock(mutex_results);
323 for (const auto & id_task : waiting_task_ids) {
324 if (result->id == id_task) {
325 RES_DBG("task id = %d pushed to result queue\n", result->id);
326
327 queue_results.emplace_back(std::move(result));
328 condition_results.notify_all();
329 return;
330 }
331 }
332}
333
334void server_response::terminate() {
335 running = false;
336 condition_results.notify_all();
337}
338
339//
340// server_response_reader
341//
342
343void server_response_reader::post_task(server_task && task, bool front) {
344 GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader");
345 GGML_ASSERT(!task.is_parent() && "not supported, use post_tasks() instead");
346 task.index = 0;
347 id_tasks.insert(task.id);
348 states.push_back(task.create_state());
349 queue_results.add_waiting_task_id(task.id);
350 queue_tasks.post(std::move(task), front);
351}
352
353void server_response_reader::post_tasks(std::vector<server_task> && tasks, bool front) {
354 GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader");
355 id_tasks = server_task::get_list_id(tasks);
356 states.reserve(tasks.size());
357 size_t index = 0;
358 for (auto & task : tasks) {
359 task.index = index++;
360 states.push_back(task.create_state());
361 // for child tasks
362 for (auto & child_task : task.child_tasks) {
363 child_task.index = index++;
364 states.push_back(child_task.create_state());
365 }
366 }
367 GGML_ASSERT(states.size() == id_tasks.size());
368 queue_results.add_waiting_task_ids(id_tasks);
369 queue_tasks.post(std::move(tasks), front);
370}
371
372bool server_response_reader::has_next() const {
373 return !cancelled && received_count < id_tasks.size();
374}
375
376// return nullptr if should_stop() is true before receiving a result
377// note: if one error is received, it will stop further processing and return error result
378server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
379 while (true) {
380 server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
381 if (result == nullptr) {
382 // timeout, check stop condition
383 if (should_stop()) {
384 SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
385 return nullptr;
386 }
387 } else {
388 if (result->is_error()) {
389 stop(); // cancel remaining tasks
390 SRV_DBG("%s", "received error result, stopping further processing\n");
391 return result;
392 }
393 if (!states.empty()) {
394 // update the generation state if needed
395 const size_t idx = result->index;
396 GGML_ASSERT(idx < states.size());
397 result->update(states[idx]);
398 }
399 if (result->is_stop()) {
400 received_count++;
401 }
402 return result;
403 }
404 }
405
406 // should not reach here
407}
408
409server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
410 batch_response batch_res;
411 batch_res.results.clear();
412 batch_res.results.resize(id_tasks.size());
413 while (has_next()) {
414 auto res = next(should_stop);
415 if (res == nullptr) {
416 batch_res.is_terminated = true;
417 return batch_res;
418 }
419 if (res->is_error()) {
420 batch_res.error = std::move(res);
421 return batch_res;
422 }
423 const size_t idx = res->index;
424 GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
425 GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
426 batch_res.results[idx] = std::move(res);
427 }
428 return batch_res;
429}
430
431void server_response_reader::stop() {
432 queue_results.remove_waiting_task_ids(id_tasks);
433 if (has_next() && !cancelled) {
434 // if tasks is not finished yet, cancel them
435 cancelled = true;
436 std::vector<server_task> cancel_tasks;
437 cancel_tasks.reserve(id_tasks.size());
438 for (const auto & id_task : id_tasks) {
439 SRV_WRN("cancel task, id_task = %d\n", id_task);
440 server_task task(SERVER_TASK_TYPE_CANCEL);
441 task.id_target = id_task;
442 queue_results.remove_waiting_task_id(id_task);
443 cancel_tasks.push_back(std::move(task));
444 }
445 // push to beginning of the queue, so it has highest priority
446 queue_tasks.post(std::move(cancel_tasks), true);
447 } else {
448 SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
449 }
450}