1#include "console.h"
   2#include "log.h"
   3#include <vector>
   4#include <iostream>
   5#include <cassert>
   6#include <cstddef>
   7#include <cctype>
   8#include <cwctype>
   9#include <cstdint>
  10#include <condition_variable>
  11#include <mutex>
  12#include <thread>
  13#include <stdarg.h>
  14
  15#if defined(_WIN32)
  16#define WIN32_LEAN_AND_MEAN
  17#ifndef NOMINMAX
  18#define NOMINMAX
  19#endif
  20#include <windows.h>
  21#include <fcntl.h>
  22#include <io.h>
  23#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
  24#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
  25#endif
  26#else
  27#include <climits>
  28#include <sys/ioctl.h>
  29#include <unistd.h>
  30#include <wchar.h>
  31#include <stdio.h>
  32#include <stdlib.h>
  33#include <signal.h>
  34#include <termios.h>
  35#endif
  36
// ANSI SGR escape sequences written by set_display() (only when advanced display is enabled).
// They are kept as string-literal macros so adjacent ones concatenate at compile time,
// e.g. `ANSI_BOLD ANSI_COLOR_GREEN`.
#define ANSI_COLOR_RED     "\x1b[31m"
#define ANSI_COLOR_GREEN   "\x1b[32m"
#define ANSI_COLOR_YELLOW  "\x1b[33m"
#define ANSI_COLOR_BLUE    "\x1b[34m"
#define ANSI_COLOR_MAGENTA "\x1b[35m"
#define ANSI_COLOR_CYAN    "\x1b[36m"
#define ANSI_COLOR_GRAY    "\x1b[90m"
// Resets all attributes (color and bold) to the terminal defaults.
#define ANSI_COLOR_RESET   "\x1b[0m"
#define ANSI_BOLD          "\x1b[1m"
  46
  47namespace console {
  48
  49#if defined (_WIN32)
  50    namespace {
  51        // Use private-use unicode values to represent special keys that are not reported
  52        // as characters (e.g. arrows on Windows). These values should never clash with
  53        // real input and let the rest of the code handle navigation uniformly.
  54        static constexpr char32_t KEY_ARROW_LEFT       = 0xE000;
  55        static constexpr char32_t KEY_ARROW_RIGHT      = 0xE001;
  56        static constexpr char32_t KEY_ARROW_UP         = 0xE002;
  57        static constexpr char32_t KEY_ARROW_DOWN       = 0xE003;
  58        static constexpr char32_t KEY_HOME             = 0xE004;
  59        static constexpr char32_t KEY_END              = 0xE005;
  60        static constexpr char32_t KEY_CTRL_ARROW_LEFT  = 0xE006;
  61        static constexpr char32_t KEY_CTRL_ARROW_RIGHT = 0xE007;
  62        static constexpr char32_t KEY_DELETE           = 0xE008;
  63    }
  64
  65    //
  66    // Console state
  67    //
  68#endif
  69
  70    static bool         advanced_display = false;
  71    static bool         simple_io        = true;
  72    static display_type current_display  = DISPLAY_TYPE_RESET;
  73
  74    static FILE*        out              = stdout;
  75
  76#if defined (_WIN32)
  77    static void*        hConsole;
  78#else
  79    static FILE*        tty              = nullptr;
  80    static termios      initial_state;
  81#endif
  82
  83    //
  84    // Init and cleanup
  85    //
  86
  87    void init(bool use_simple_io, bool use_advanced_display) {
  88        advanced_display = use_advanced_display;
  89        simple_io = use_simple_io;
  90#if defined(_WIN32)
  91        // Windows-specific console initialization
  92        DWORD dwMode = 0;
  93        hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
  94        if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
  95            hConsole = GetStdHandle(STD_ERROR_HANDLE);
  96            if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
  97                hConsole = nullptr;
  98                simple_io = true;
  99            }
 100        }
 101        if (hConsole) {
 102            // Check conditions combined to reduce nesting
 103            if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
 104                !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
 105                advanced_display = false;
 106            }
 107            // Set console output codepage to UTF8
 108            SetConsoleOutputCP(CP_UTF8);
 109        }
 110        HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
 111        if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
 112            // Set console input codepage to UTF16
 113            _setmode(_fileno(stdin), _O_WTEXT);
 114
 115            // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
 116            if (simple_io) {
 117                dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
 118            } else {
 119                dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
 120            }
 121            if (!SetConsoleMode(hConIn, dwMode)) {
 122                simple_io = true;
 123            }
 124        }
 125        if (simple_io) {
 126            _setmode(_fileno(stdin), _O_U8TEXT);
 127        }
 128#else
 129        // POSIX-specific console initialization
 130        if (!simple_io) {
 131            struct termios new_termios;
 132            tcgetattr(STDIN_FILENO, &initial_state);
 133            new_termios = initial_state;
 134            new_termios.c_lflag &= ~(ICANON | ECHO);
 135            new_termios.c_cc[VMIN] = 1;
 136            new_termios.c_cc[VTIME] = 0;
 137            tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
 138
 139            tty = fopen("/dev/tty", "w+");
 140            if (tty != nullptr) {
 141                out = tty;
 142            }
 143        }
 144
 145        setlocale(LC_ALL, "");
 146#endif
 147    }
 148
 149    void cleanup() {
 150        // Reset console display
 151        set_display(DISPLAY_TYPE_RESET);
 152
 153#if !defined(_WIN32)
 154        // Restore settings on POSIX systems
 155        if (!simple_io) {
 156            if (tty != nullptr) {
 157                out = stdout;
 158                fclose(tty);
 159                tty = nullptr;
 160            }
 161            tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
 162        }
 163#endif
 164    }
 165
 166    //
 167    // Display and IO
 168    //
 169
 170    // Keep track of current display and only emit ANSI code if it changes
 171    void set_display(display_type display) {
 172        if (advanced_display && current_display != display) {
 173            common_log_flush(common_log_main());
 174            switch(display) {
 175                case DISPLAY_TYPE_RESET:
 176                    fprintf(out, ANSI_COLOR_RESET);
 177                    break;
 178                case DISPLAY_TYPE_INFO:
 179                    fprintf(out, ANSI_COLOR_MAGENTA);
 180                    break;
 181                case DISPLAY_TYPE_PROMPT:
 182                    fprintf(out, ANSI_COLOR_YELLOW);
 183                    break;
 184                case DISPLAY_TYPE_REASONING:
 185                    fprintf(out, ANSI_COLOR_GRAY);
 186                    break;
 187                case DISPLAY_TYPE_USER_INPUT:
 188                    fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
 189                    break;
 190                case DISPLAY_TYPE_ERROR:
 191                    fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
 192            }
 193            current_display = display;
 194            fflush(out);
 195        }
 196    }
 197
 198    static char32_t getchar32() {
 199#if defined(_WIN32)
 200        HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
 201        wchar_t high_surrogate = 0;
 202
 203        while (true) {
 204            INPUT_RECORD record;
 205            DWORD count;
 206            if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
 207                return WEOF;
 208            }
 209
 210            if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
 211                wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
 212                if (wc == 0) {
 213                    const DWORD ctrl_mask = LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED;
 214                    const bool ctrl_pressed = (record.Event.KeyEvent.dwControlKeyState & ctrl_mask) != 0;
 215                    switch (record.Event.KeyEvent.wVirtualKeyCode) {
 216                        case VK_LEFT:   return ctrl_pressed ? KEY_CTRL_ARROW_LEFT  : KEY_ARROW_LEFT;
 217                        case VK_RIGHT:  return ctrl_pressed ? KEY_CTRL_ARROW_RIGHT : KEY_ARROW_RIGHT;
 218                        case VK_UP:     return KEY_ARROW_UP;
 219                        case VK_DOWN:   return KEY_ARROW_DOWN;
 220                        case VK_HOME:   return KEY_HOME;
 221                        case VK_END:    return KEY_END;
 222                        case VK_DELETE: return KEY_DELETE;
 223                        default:        continue;
 224                    }
 225                }
 226
 227                if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
 228                    high_surrogate = wc;
 229                    continue;
 230                }
 231                if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
 232                    if (high_surrogate != 0) { // Check if we have a high surrogate
 233                        return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
 234                    }
 235                }
 236
 237                high_surrogate = 0; // Reset the high surrogate
 238                return static_cast<char32_t>(wc);
 239            }
 240        }
 241#else
 242        wchar_t wc = getwchar();
 243        if (static_cast<wint_t>(wc) == WEOF) {
 244            return WEOF;
 245        }
 246
 247#if WCHAR_MAX == 0xFFFF
 248        if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
 249            wchar_t low_surrogate = getwchar();
 250            if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
 251                return (static_cast<char32_t>(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
 252            }
 253        }
 254        if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
 255            return 0xFFFD; // Return the replacement character U+FFFD
 256        }
 257#endif
 258
 259        return static_cast<char32_t>(wc);
 260#endif
 261    }
 262
 263    static void pop_cursor() {
 264#if defined(_WIN32)
 265        if (hConsole != NULL) {
 266            CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
 267            GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
 268
 269            COORD newCursorPosition = bufferInfo.dwCursorPosition;
 270            if (newCursorPosition.X == 0) {
 271                newCursorPosition.X = bufferInfo.dwSize.X - 1;
 272                newCursorPosition.Y -= 1;
 273            } else {
 274                newCursorPosition.X -= 1;
 275            }
 276
 277            SetConsoleCursorPosition(hConsole, newCursorPosition);
 278            return;
 279        }
 280#endif
 281        putc('\b', out);
 282    }
 283
 284    static int estimateWidth(char32_t codepoint) {
 285#if defined(_WIN32)
 286        (void)codepoint;
 287        return 1;
 288#else
 289        return wcwidth(codepoint);
 290#endif
 291    }
 292
 293    static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
 294#if defined(_WIN32)
 295        CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
 296        if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
 297            // go with the default
 298            return expectedWidth;
 299        }
 300        COORD initialPosition = bufferInfo.dwCursorPosition;
 301        DWORD nNumberOfChars = length;
 302        WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
 303
 304        CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
 305        GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
 306
 307        // Figure out our real position if we're in the last column
 308        if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
 309            DWORD nNumberOfChars;
 310            WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL);
 311            GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
 312        }
 313
 314        int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
 315        if (width < 0) {
 316            width += newBufferInfo.dwSize.X;
 317        }
 318        return width;
 319#else
 320        // We can trust expectedWidth if we've got one
 321        if (expectedWidth >= 0 || tty == nullptr) {
 322            fwrite(utf8_codepoint, length, 1, out);
 323            return expectedWidth;
 324        }
 325
 326        fputs("\033[6n", tty); // Query cursor position
 327        int x1;
 328        int y1;
 329        int x2;
 330        int y2;
 331        int results = 0;
 332        results = fscanf(tty, "\033[%d;%dR", &y1, &x1);
 333
 334        fwrite(utf8_codepoint, length, 1, tty);
 335
 336        fputs("\033[6n", tty); // Query cursor position
 337        results += fscanf(tty, "\033[%d;%dR", &y2, &x2);
 338
 339        if (results != 4) {
 340            return expectedWidth;
 341        }
 342
 343        int width = x2 - x1;
 344        if (width < 0) {
 345            // Calculate the width considering text wrapping
 346            struct winsize w;
 347            ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
 348            width += w.ws_col;
 349        }
 350        return width;
 351#endif
 352    }
 353
 354    static void replace_last(char ch) {
 355#if defined(_WIN32)
 356        pop_cursor();
 357        put_codepoint(&ch, 1, 1);
 358#else
 359        fprintf(out, "\b%c", ch);
 360#endif
 361    }
 362
 363    static char32_t decode_utf8(const std::string & input, size_t pos, size_t & advance) {
 364        unsigned char c = static_cast<unsigned char>(input[pos]);
 365        if ((c & 0x80u) == 0u) {
 366            advance = 1;
 367            return c;
 368        }
 369        if ((c & 0xE0u) == 0xC0u && pos + 1 < input.size()) {
 370            unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
 371            if ((c1 & 0xC0u) != 0x80u) {
 372                advance = 1;
 373                return 0xFFFD;
 374            }
 375            advance = 2;
 376            return ((c & 0x1Fu) << 6) | (static_cast<unsigned char>(input[pos + 1]) & 0x3Fu);
 377        }
 378        if ((c & 0xF0u) == 0xE0u && pos + 2 < input.size()) {
 379            unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
 380            unsigned char c2 = static_cast<unsigned char>(input[pos + 2]);
 381            if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u) {
 382                advance = 1;
 383                return 0xFFFD;
 384            }
 385            advance = 3;
 386            return ((c & 0x0Fu) << 12) |
 387                   ((static_cast<unsigned char>(input[pos + 1]) & 0x3Fu) << 6) |
 388                   (static_cast<unsigned char>(input[pos + 2]) & 0x3Fu);
 389        }
 390        if ((c & 0xF8u) == 0xF0u && pos + 3 < input.size()) {
 391            unsigned char c1 = static_cast<unsigned char>(input[pos + 1]);
 392            unsigned char c2 = static_cast<unsigned char>(input[pos + 2]);
 393            unsigned char c3 = static_cast<unsigned char>(input[pos + 3]);
 394            if ((c1 & 0xC0u) != 0x80u || (c2 & 0xC0u) != 0x80u || (c3 & 0xC0u) != 0x80u) {
 395                advance = 1;
 396                return 0xFFFD;
 397            }
 398            advance = 4;
 399            return ((c & 0x07u) << 18) |
 400                   ((static_cast<unsigned char>(input[pos + 1]) & 0x3Fu) << 12) |
 401                   ((static_cast<unsigned char>(input[pos + 2]) & 0x3Fu) << 6) |
 402                   (static_cast<unsigned char>(input[pos + 3]) & 0x3Fu);
 403        }
 404
 405        advance = 1;
 406        return 0xFFFD; // replacement character for invalid input
 407    }
 408
 409    static void append_utf8(char32_t ch, std::string & out) {
 410        if (ch <= 0x7F) {
 411            out.push_back(static_cast<unsigned char>(ch));
 412        } else if (ch <= 0x7FF) {
 413            out.push_back(static_cast<unsigned char>(0xC0 | ((ch >> 6) & 0x1F)));
 414            out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 415        } else if (ch <= 0xFFFF) {
 416            out.push_back(static_cast<unsigned char>(0xE0 | ((ch >> 12) & 0x0F)));
 417            out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
 418            out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 419        } else if (ch <= 0x10FFFF) {
 420            out.push_back(static_cast<unsigned char>(0xF0 | ((ch >> 18) & 0x07)));
 421            out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 12) & 0x3F)));
 422            out.push_back(static_cast<unsigned char>(0x80 | ((ch >> 6) & 0x3F)));
 423            out.push_back(static_cast<unsigned char>(0x80 | (ch & 0x3F)));
 424        } else {
 425            // Invalid Unicode code point
 426        }
 427    }
 428
 429    // Helper function to remove the last UTF-8 character from a string
 430    static size_t prev_utf8_char_pos(const std::string & line, size_t pos) {
 431        if (pos == 0) return 0;
 432        pos--;
 433        while (pos > 0 && (line[pos] & 0xC0) == 0x80) {
 434            pos--;
 435        }
 436        return pos;
 437    }
 438
 439    static size_t next_utf8_char_pos(const std::string & line, size_t pos) {
 440        if (pos >= line.length()) return line.length();
 441        pos++;
 442        while (pos < line.length() && (line[pos] & 0xC0) == 0x80) {
 443            pos++;
 444        }
 445        return pos;
 446    }
 447
    // Forward declarations of cursor-navigation helpers defined further down in this file.
    // Some are used before their definitions (e.g. move_cursor() in delete_at_cursor(),
    // move_to_line_start() in set_line_contents()).
    static void move_cursor(int delta);
    static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
    static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
    static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths);
    static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line);
 453
 454    static void delete_at_cursor(std::string & line, std::vector<int> & widths, size_t & char_pos, size_t & byte_pos) {
 455        if (char_pos >= widths.size()) {
 456            return;
 457        }
 458
 459        size_t next_pos = next_utf8_char_pos(line, byte_pos);
 460        int w = widths[char_pos];
 461        size_t char_len = next_pos - byte_pos;
 462
 463        line.erase(byte_pos, char_len);
 464        widths.erase(widths.begin() + char_pos);
 465
 466        size_t p = byte_pos;
 467        int tail_width = 0;
 468        for (size_t i = char_pos; i < widths.size(); ++i) {
 469            size_t following = next_utf8_char_pos(line, p);
 470            put_codepoint(line.c_str() + p, following - p, widths[i]);
 471            tail_width += widths[i];
 472            p = following;
 473        }
 474
 475        for (int i = 0; i < w; ++i) {
 476            fputc(' ', out);
 477        }
 478
 479        move_cursor(-(tail_width + w));
 480    }
 481
 482    static void clear_current_line(const std::vector<int> & widths) {
 483        int total_width = 0;
 484        for (int w : widths) {
 485            total_width += (w > 0 ? w : 1);
 486        }
 487
 488        if (total_width > 0) {
 489            std::string spaces(total_width, ' ');
 490            fwrite(spaces.c_str(), 1, total_width, out);
 491            move_cursor(-total_width);
 492        }
 493    }
 494
 495    static void set_line_contents(std::string new_line, std::string & line, std::vector<int> & widths, size_t & char_pos,
 496                                  size_t & byte_pos) {
 497        move_to_line_start(char_pos, byte_pos, widths);
 498        clear_current_line(widths);
 499
 500        line = std::move(new_line);
 501        widths.clear();
 502        byte_pos = 0;
 503        char_pos = 0;
 504
 505        size_t idx = 0;
 506        while (idx < line.size()) {
 507            size_t advance = 0;
 508            char32_t cp = decode_utf8(line, idx, advance);
 509            int expected_width = estimateWidth(cp);
 510            int real_width = put_codepoint(line.c_str() + idx, advance, expected_width);
 511            if (real_width < 0) real_width = 0;
 512            widths.push_back(real_width);
 513            idx += advance;
 514            ++char_pos;
 515            byte_pos = idx;
 516        }
 517    }
 518
 519    static void move_to_line_start(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths) {
 520        int back_width = 0;
 521        for (size_t i = 0; i < char_pos; ++i) {
 522            back_width += widths[i];
 523        }
 524        move_cursor(-back_width);
 525        char_pos = 0;
 526        byte_pos = 0;
 527    }
 528
 529    static void move_to_line_end(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
 530        int forward_width = 0;
 531        for (size_t i = char_pos; i < widths.size(); ++i) {
 532            forward_width += widths[i];
 533        }
 534        move_cursor(forward_width);
 535        char_pos = widths.size();
 536        byte_pos = line.length();
 537    }
 538
 539    static bool has_ctrl_modifier(const std::string & params) {
 540        size_t start = 0;
 541        while (start < params.size()) {
 542            size_t end = params.find(';', start);
 543            size_t len = (end == std::string::npos) ? params.size() - start : end - start;
 544            if (len > 0) {
 545                int value = 0;
 546                for (size_t i = 0; i < len; ++i) {
 547                    char ch = params[start + i];
 548                    if (!std::isdigit(static_cast<unsigned char>(ch))) {
 549                        value = -1;
 550                        break;
 551                    }
 552                    value = value * 10 + (ch - '0');
 553                }
 554                if (value == 5) {
 555                    return true;
 556                }
 557            }
 558
 559            if (end == std::string::npos) {
 560                break;
 561            }
 562            start = end + 1;
 563        }
 564        return false;
 565    }
 566
 567    static bool is_space_codepoint(char32_t cp) {
 568        return std::iswspace(static_cast<wint_t>(cp)) != 0;
 569    }
 570
 571    static void move_word_left(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
 572        if (char_pos == 0) {
 573            return;
 574        }
 575
 576        size_t new_char_pos = char_pos;
 577        size_t new_byte_pos = byte_pos;
 578        int move_width = 0;
 579
 580        while (new_char_pos > 0) {
 581            size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos);
 582            size_t advance = 0;
 583            char32_t cp = decode_utf8(line, prev_byte, advance);
 584            if (!is_space_codepoint(cp)) {
 585                break;
 586            }
 587            move_width += widths[new_char_pos - 1];
 588            new_char_pos--;
 589            new_byte_pos = prev_byte;
 590        }
 591
 592        while (new_char_pos > 0) {
 593            size_t prev_byte = prev_utf8_char_pos(line, new_byte_pos);
 594            size_t advance = 0;
 595            char32_t cp = decode_utf8(line, prev_byte, advance);
 596            if (is_space_codepoint(cp)) {
 597                break;
 598            }
 599            move_width += widths[new_char_pos - 1];
 600            new_char_pos--;
 601            new_byte_pos = prev_byte;
 602        }
 603
 604        move_cursor(-move_width);
 605        char_pos = new_char_pos;
 606        byte_pos = new_byte_pos;
 607    }
 608
 609    static void move_word_right(size_t & char_pos, size_t & byte_pos, const std::vector<int> & widths, const std::string & line) {
 610        if (char_pos >= widths.size()) {
 611            return;
 612        }
 613
 614        size_t new_char_pos = char_pos;
 615        size_t new_byte_pos = byte_pos;
 616        int move_width = 0;
 617
 618        while (new_char_pos < widths.size()) {
 619            size_t advance = 0;
 620            char32_t cp = decode_utf8(line, new_byte_pos, advance);
 621            if (!is_space_codepoint(cp)) {
 622                break;
 623            }
 624            move_width += widths[new_char_pos];
 625            new_char_pos++;
 626            new_byte_pos += advance;
 627        }
 628
 629        while (new_char_pos < widths.size()) {
 630            size_t advance = 0;
 631            char32_t cp = decode_utf8(line, new_byte_pos, advance);
 632            if (is_space_codepoint(cp)) {
 633                break;
 634            }
 635            move_width += widths[new_char_pos];
 636            new_char_pos++;
 637            new_byte_pos += advance;
 638        }
 639
 640        while (new_char_pos < widths.size()) {
 641            size_t advance = 0;
 642            char32_t cp = decode_utf8(line, new_byte_pos, advance);
 643            if (!is_space_codepoint(cp)) {
 644                break;
 645            }
 646            move_width += widths[new_char_pos];
 647            new_char_pos++;
 648            new_byte_pos += advance;
 649        }
 650
 651        move_cursor(move_width);
 652        char_pos = new_char_pos;
 653        byte_pos = new_byte_pos;
 654    }
 655
 656    static void move_cursor(int delta) {
 657        if (delta == 0) return;
 658#if defined(_WIN32)
 659        if (hConsole != NULL) {
 660            CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
 661            GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
 662            COORD newCursorPosition = bufferInfo.dwCursorPosition;
 663            int width = bufferInfo.dwSize.X;
 664            int newX = newCursorPosition.X + delta;
 665            int newY = newCursorPosition.Y;
 666
 667            while (newX >= width) {
 668                newX -= width;
 669                newY++;
 670            }
 671            while (newX < 0) {
 672                newX += width;
 673                newY--;
 674            }
 675
 676            newCursorPosition.X = newX;
 677            newCursorPosition.Y = newY;
 678            SetConsoleCursorPosition(hConsole, newCursorPosition);
 679        }
 680#else
 681        if (delta < 0) {
 682            for (int i = 0; i < -delta; i++) fprintf(out, "\b");
 683        } else {
 684            for (int i = 0; i < delta; i++) fprintf(out, "\033[C");
 685        }
 686#endif
 687    }
 688
 689    struct history_t {
 690        std::vector<std::string> entries;
 691        size_t viewing_idx = SIZE_MAX;
 692        std::string backup_line; // current line before viewing history
 693        void add(const std::string & line) {
 694            if (line.empty()) {
 695                return;
 696            }
 697            // avoid duplicates with the last entry
 698            if (entries.empty() || entries.back() != line) {
 699                entries.push_back(line);
 700            }
 701            // also clear viewing state
 702            end_viewing();
 703        }
 704        bool prev(std::string & cur_line) {
 705            if (entries.empty()) {
 706                return false;
 707            }
 708            if (viewing_idx == SIZE_MAX) {
 709                return false;
 710            }
 711            if (viewing_idx > 0) {
 712                viewing_idx--;
 713            }
 714            cur_line = entries[viewing_idx];
 715            return true;
 716        }
 717        bool next(std::string & cur_line) {
 718            if (entries.empty() || viewing_idx == SIZE_MAX) {
 719                return false;
 720            }
 721            viewing_idx++;
 722            if (viewing_idx >= entries.size()) {
 723                cur_line = backup_line;
 724                end_viewing();
 725            } else {
 726                cur_line = entries[viewing_idx];
 727            }
 728            return true;
 729        }
 730        void begin_viewing(const std::string & line) {
 731            backup_line = line;
 732            viewing_idx = entries.size();
 733        }
 734        void end_viewing() {
 735            viewing_idx = SIZE_MAX;
 736            backup_line.clear();
 737        }
 738        bool is_viewing() const {
 739            return viewing_idx != SIZE_MAX;
 740        }
 741    } history;
 742
 743    static bool readline_advanced(std::string & line, bool multiline_input) {
 744        if (out != stdout) {
 745            fflush(stdout);
 746        }
 747
 748        line.clear();
 749        std::vector<int> widths;
 750        bool is_special_char = false;
 751        bool end_of_stream = false;
 752
 753        size_t byte_pos = 0; // current byte index
 754        size_t char_pos = 0; // current character index (one char can be multiple bytes)
 755
 756        char32_t input_char;
 757        while (true) {
 758            assert(char_pos <= byte_pos);
 759            assert(char_pos <= widths.size());
 760            auto history_prev = [&]() {
 761                if (!history.is_viewing()) {
 762                    history.begin_viewing(line);
 763                }
 764                std::string new_line;
 765                if (!history.prev(new_line)) {
 766                    return;
 767                }
 768                set_line_contents(new_line, line, widths, char_pos, byte_pos);
 769            };
 770            auto history_next = [&]() {
 771                if (history.is_viewing()) {
 772                    std::string new_line;
 773                    if (!history.next(new_line)) {
 774                        return;
 775                    }
 776                    set_line_contents(new_line, line, widths, char_pos, byte_pos);
 777                }
 778            };
 779
 780            fflush(out); // Ensure all output is displayed before waiting for input
 781            input_char = getchar32();
 782
 783            if (input_char == '\r' || input_char == '\n') {
 784                break;
 785            }
 786
 787            if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D */) {
 788                end_of_stream = true;
 789                break;
 790            }
 791
 792            if (is_special_char) {
 793                replace_last(line.back());
 794                is_special_char = false;
 795            }
 796
 797            if (input_char == '\033') { // Escape sequence
 798                char32_t code = getchar32();
 799                if (code == '[') {
 800                    std::string params;
 801                    while (true) {
 802                        code = getchar32();
 803                        if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~' || code == (char32_t) WEOF) {
 804                            break;
 805                        }
 806                        params.push_back(static_cast<char>(code));
 807                    }
 808
 809                    const bool ctrl_modifier = has_ctrl_modifier(params);
 810
 811                    if (code == 'D') { // left
 812                        if (ctrl_modifier) {
 813                            move_word_left(char_pos, byte_pos, widths, line);
 814                        } else if (char_pos > 0) {
 815                            int w = widths[char_pos - 1];
 816                            move_cursor(-w);
 817                            char_pos--;
 818                            byte_pos = prev_utf8_char_pos(line, byte_pos);
 819                        }
 820                    } else if (code == 'C') { // right
 821                        if (ctrl_modifier) {
 822                            move_word_right(char_pos, byte_pos, widths, line);
 823                        } else if (char_pos < widths.size()) {
 824                            int w = widths[char_pos];
 825                            move_cursor(w);
 826                            char_pos++;
 827                            byte_pos = next_utf8_char_pos(line, byte_pos);
 828                        }
 829                    } else if (code == 'H') { // home
 830                        move_to_line_start(char_pos, byte_pos, widths);
 831                    } else if (code == 'F') { // end
 832                        move_to_line_end(char_pos, byte_pos, widths, line);
 833                    } else if (code == 'A' || code == 'B') {
 834                        // up/down
 835                        if (code == 'A') {
 836                            history_prev();
 837                            is_special_char = false;
 838                        } else if (code == 'B') {
 839                            history_next();
 840                            is_special_char = false;
 841                        }
 842                    } else if ((code == '~' || (code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z')) && !params.empty()) {
 843                        std::string digits;
 844                        for (char ch : params) {
 845                            if (ch == ';') {
 846                                break;
 847                            }
 848                            if (std::isdigit(static_cast<unsigned char>(ch))) {
 849                                digits.push_back(ch);
 850                            }
 851                        }
 852
 853                        if (code == '~') {
 854                            if (digits == "1" || digits == "7") { // home
 855                                move_to_line_start(char_pos, byte_pos, widths);
 856                            } else if (digits == "4" || digits == "8") { // end
 857                                move_to_line_end(char_pos, byte_pos, widths, line);
 858                            } else if (digits == "3") { // delete
 859                                delete_at_cursor(line, widths, char_pos, byte_pos);
 860                            }
 861                        }
 862                    }
 863                } else if (code == 0x1B) {
 864                    // Discard the rest of the escape sequence
 865                    while ((code = getchar32()) != (char32_t) WEOF) {
 866                        if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
 867                            break;
 868                        }
 869                    }
 870                }
 871#if defined(_WIN32)
 872            } else if (input_char == KEY_ARROW_LEFT) {
 873                if (char_pos > 0) {
 874                    int w = widths[char_pos - 1];
 875                    move_cursor(-w);
 876                    char_pos--;
 877                    byte_pos = prev_utf8_char_pos(line, byte_pos);
 878                }
 879            } else if (input_char == KEY_ARROW_RIGHT) {
 880                if (char_pos < widths.size()) {
 881                    int w = widths[char_pos];
 882                    move_cursor(w);
 883                    char_pos++;
 884                    byte_pos = next_utf8_char_pos(line, byte_pos);
 885                }
 886            } else if (input_char == KEY_CTRL_ARROW_LEFT) {
 887                move_word_left(char_pos, byte_pos, widths, line);
 888            } else if (input_char == KEY_CTRL_ARROW_RIGHT) {
 889                move_word_right(char_pos, byte_pos, widths, line);
 890            } else if (input_char == KEY_HOME) {
 891                move_to_line_start(char_pos, byte_pos, widths);
 892            } else if (input_char == KEY_END) {
 893                move_to_line_end(char_pos, byte_pos, widths, line);
 894            } else if (input_char == KEY_DELETE) {
 895                delete_at_cursor(line, widths, char_pos, byte_pos);
 896            } else if (input_char == KEY_ARROW_UP || input_char == KEY_ARROW_DOWN) {
 897                if (input_char == KEY_ARROW_UP) {
 898                    history_prev();
 899                    is_special_char = false;
 900                } else if (input_char == KEY_ARROW_DOWN) {
 901                    history_next();
 902                    is_special_char = false;
 903                }
 904#endif
 905            } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
 906                if (char_pos > 0) {
 907                    int w = widths[char_pos - 1];
 908                    move_cursor(-w);
 909                    char_pos--;
 910                    size_t prev_pos = prev_utf8_char_pos(line, byte_pos);
 911                    size_t char_len = byte_pos - prev_pos;
 912                    byte_pos = prev_pos;
 913
 914                    // remove the character
 915                    line.erase(byte_pos, char_len);
 916                    widths.erase(widths.begin() + char_pos);
 917
 918                    // redraw tail
 919                    size_t p = byte_pos;
 920                    int tail_width = 0;
 921                    for (size_t i = char_pos; i < widths.size(); ++i) {
 922                        size_t next_p = next_utf8_char_pos(line, p);
 923                        put_codepoint(line.c_str() + p, next_p - p, widths[i]);
 924                        tail_width += widths[i];
 925                        p = next_p;
 926                    }
 927
 928                    // clear display
 929                    for (int i = 0; i < w; ++i) {
 930                        fputc(' ', out);
 931                    }
 932                    move_cursor(-(tail_width + w));
 933                }
 934            } else {
 935                // insert character
 936                std::string new_char_str;
 937                append_utf8(input_char, new_char_str);
 938                int w = estimateWidth(input_char);
 939
 940                if (char_pos == widths.size()) {
 941                    // insert at the end
 942                    line += new_char_str;
 943                    int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w);
 944                    if (real_w < 0) real_w = 0;
 945                    widths.push_back(real_w);
 946                    byte_pos += new_char_str.length();
 947                    char_pos++;
 948                } else {
 949                    // insert in middle
 950                    line.insert(byte_pos, new_char_str);
 951
 952                    int real_w = put_codepoint(new_char_str.c_str(), new_char_str.length(), w);
 953                    if (real_w < 0) real_w = 0;
 954
 955                    widths.insert(widths.begin() + char_pos, real_w);
 956
 957                    // print the tail
 958                    size_t p = byte_pos + new_char_str.length();
 959                    int tail_width = 0;
 960                    for (size_t i = char_pos + 1; i < widths.size(); ++i) {
 961                        size_t next_p = next_utf8_char_pos(line, p);
 962                        put_codepoint(line.c_str() + p, next_p - p, widths[i]);
 963                        tail_width += widths[i];
 964                        p = next_p;
 965                    }
 966
 967                    move_cursor(-tail_width);
 968
 969                    byte_pos += new_char_str.length();
 970                    char_pos++;
 971                }
 972            }
 973
 974            if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
 975                replace_last(line.back());
 976                is_special_char = true;
 977            }
 978        }
 979
 980        bool has_more = multiline_input;
 981        if (is_special_char) {
 982            replace_last(' ');
 983            pop_cursor();
 984
 985            char last = line.back();
 986            line.pop_back();
 987            if (last == '\\') {
 988                line += '\n';
 989                fputc('\n', out);
 990                has_more = !has_more;
 991            } else {
 992                // llama will just eat the single space, it won't act as a space
 993                if (line.length() == 1 && line.back() == ' ') {
 994                    line.clear();
 995                    pop_cursor();
 996                }
 997                has_more = false;
 998            }
 999        } else {
1000            if (end_of_stream) {
1001                has_more = false;
1002            } else {
1003                line += '\n';
1004                fputc('\n', out);
1005            }
1006        }
1007
1008        if (!end_of_stream && !line.empty()) {
1009            // remove the trailing newline for history storage
1010            if (!line.empty() && line.back() == '\n') {
1011                line.pop_back();
1012            }
1013            // TODO: maybe support multiline history entries?
1014            history.add(line);
1015        }
1016
1017        fflush(out);
1018        return has_more;
1019    }
1020
1021    static bool readline_simple(std::string & line, bool multiline_input) {
1022#if defined(_WIN32)
1023        std::wstring wline;
1024        if (!std::getline(std::wcin, wline)) {
1025            // Input stream is bad or EOF received
1026            line.clear();
1027            GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
1028            return false;
1029        }
1030
1031        int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
1032        line.resize(size_needed);
1033        WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
1034#else
1035        if (!std::getline(std::cin, line)) {
1036            // Input stream is bad or EOF received
1037            line.clear();
1038            return false;
1039        }
1040#endif
1041        if (!line.empty()) {
1042            char last = line.back();
1043            if (last == '/') { // Always return control on '/' symbol
1044                line.pop_back();
1045                return false;
1046            }
1047            if (last == '\\') { // '\\' changes the default action
1048                line.pop_back();
1049                multiline_input = !multiline_input;
1050            }
1051        }
1052        line += '\n';
1053
1054        // By default, continue input if multiline_input is set
1055        return multiline_input;
1056    }
1057
1058    bool readline(std::string & line, bool multiline_input) {
1059        if (simple_io) {
1060            return readline_simple(line, multiline_input);
1061        }
1062        return readline_advanced(line, multiline_input);
1063    }
1064
1065    namespace spinner {
1066        static const char LOADING_CHARS[] = {'|', '/', '-', '\\'};
1067        static std::condition_variable cv_stop;
1068        static std::thread th;
1069        static size_t frame = 0; // only modified by one thread
1070        static bool running = false;
1071        static std::mutex mtx;
1072        static auto wait_time = std::chrono::milliseconds(100);
1073        static void draw_next_frame() {
1074            // don't need lock because only one thread modifies running
1075            frame = (frame + 1) % sizeof(LOADING_CHARS);
1076            replace_last(LOADING_CHARS[frame]);
1077            fflush(out);
1078        }
1079        void start() {
1080            std::unique_lock<std::mutex> lock(mtx);
1081            if (simple_io || running) {
1082                return;
1083            }
1084            common_log_flush(common_log_main());
1085            fprintf(out, "%c", LOADING_CHARS[0]);
1086            fflush(out);
1087            frame = 1;
1088            running = true;
1089            th = std::thread([]() {
1090                std::unique_lock<std::mutex> lock(mtx);
1091                while (true) {
1092                    if (cv_stop.wait_for(lock, wait_time, []{ return !running; })) {
1093                        break;
1094                    }
1095                    draw_next_frame();
1096                }
1097            });
1098        }
1099        void stop() {
1100            {
1101                std::unique_lock<std::mutex> lock(mtx);
1102                if (simple_io || !running) {
1103                    return;
1104                }
1105                running = false;
1106                cv_stop.notify_all();
1107            }
1108            if (th.joinable()) {
1109                th.join();
1110            }
1111            replace_last(' ');
1112            pop_cursor();
1113            fflush(out);
1114        }
1115    }
1116
1117    void log(const char * fmt, ...) {
1118        va_list args;
1119        va_start(args, fmt);
1120        vfprintf(out, fmt, args);
1121        va_end(args);
1122    }
1123
1124    void error(const char * fmt, ...) {
1125        va_list args;
1126        va_start(args, fmt);
1127        display_type cur = current_display;
1128        set_display(DISPLAY_TYPE_ERROR);
1129        vfprintf(out, fmt, args);
1130        set_display(cur); // restore previous color
1131        va_end(args);
1132    }
1133
1134    void flush() {
1135        fflush(out);
1136    }
1137}