1#include "json-partial.h"
  2
  3#include "log.h"
  4
  5#include <nlohmann/json.hpp>
  6
  7#include <string>
  8#include <regex>
  9
 10using json = nlohmann::ordered_json;
 11
 12enum common_json_stack_element_type {
 13    COMMON_JSON_STACK_ELEMENT_OBJECT,
 14    COMMON_JSON_STACK_ELEMENT_KEY,
 15    COMMON_JSON_STACK_ELEMENT_ARRAY,
 16};
 17
 18struct common_json_stack_element {
 19    common_json_stack_element_type type;
 20    std::string key;
 21};
 22
 23bool common_json_parse(
 24    const std::string & input,
 25    const std::string & healing_marker,
 26    common_json & out)
 27{
 28    std::string::const_iterator it = input.begin();
 29    const auto end = input.end();
 30    return common_json_parse(it, end, healing_marker, out);
 31}
 32
 33bool common_json_parse(
 34    std::string::const_iterator & it,
 35    const std::string::const_iterator & end,
 36    const std::string & healing_marker,
 37    common_json & out)
 38{
 39    // // https://json.nlohmann.me/features/parsing/sax_interface/
 40    struct json_error_locator : public nlohmann::json_sax<json> {
 41        std::size_t position;
 42        bool found_error;
 43        std::string last_token;
 44        std::string exception_message;
 45        std::vector<common_json_stack_element> stack;
 46
 47        json_error_locator() : position(0), found_error(false) {}
 48
 49        bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT
 50            this->position = position - 1;
 51            this->found_error = true;
 52            this->last_token = last_token;
 53            this->exception_message = ex.what();
 54            return false;
 55        }
 56        void close_value() {
 57            if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) {
 58                stack.pop_back();
 59            }
 60        }
 61        bool null() override { // NOLINT
 62            close_value();
 63            return true;
 64        }
 65        bool boolean(bool) override { // NOLINT
 66            close_value();
 67            return true;
 68        }
 69        bool number_integer(number_integer_t) override { // NOLINT
 70            close_value();
 71            return true;
 72        }
 73        bool number_unsigned(number_unsigned_t) override { // NOLINT
 74            close_value();
 75            return true;
 76        }
 77        bool number_float(number_float_t, const string_t &) override { // NOLINT
 78            close_value();
 79            return true;
 80        }
 81        bool string(string_t &) override { // NOLINT
 82            close_value();
 83            return true;
 84        }
 85        bool binary(binary_t &) override { // NOLINT
 86            close_value();
 87            return true;
 88        }
 89        bool start_object(std::size_t) override { // NOLINT
 90            stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""});
 91            return true;
 92        }
 93        bool end_object() override {
 94            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT);
 95            stack.pop_back();
 96            close_value();
 97            return true;
 98        }
 99        bool key(string_t & key) override { // NOLINT
100            stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key});
101            return true;
102        }
103        bool start_array(std::size_t) override { // NOLINT
104            stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""});
105            return true;
106        }
107        bool end_array() override {
108            GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY);
109            stack.pop_back();
110            close_value();
111            return true;
112        }
113    };
114    json_error_locator err_loc;
115    auto start = it;
116    json::sax_parse(it, end, &err_loc);
117
118    if (err_loc.found_error) {
119        it = start;
120        auto temptative_end = it + err_loc.position;
121        // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str());
122
123        auto input = std::string(it, temptative_end);
124        try {
125            out.json = json::parse(input);
126            // out.json = json::parse(it, temptative_end);
127            it = temptative_end;
128            return true;
129        } catch (const std::exception & ex) {
130            // No, needs healing.
131            LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str());
132        }
133        auto can_parse = [](const std::string & str) {
134            try {
135                auto _ = json::parse(str); // NOLINT
136                return true;
137            } catch (const std::exception &) {
138                return false;
139            }
140        };
141        if (!healing_marker.empty() && !err_loc.stack.empty()) {
142            std::string str(it, temptative_end);
143            auto last_non_sp_pos = str.find_last_not_of(" \n\r\t");
144            if (last_non_sp_pos == std::string::npos) {
145                throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
146            }
147            auto last_non_sp_char = str[last_non_sp_pos];
148            // Used to detect stops on a number, which may not be complete.
149            auto was_maybe_number = [&]() {
150                if (!str.empty() && std::isspace(str.back())) {
151                    return false;
152                }
153                return std::isdigit(last_non_sp_char) ||
154                    last_non_sp_char == '.' ||
155                    last_non_sp_char == 'e' ||
156                    last_non_sp_char == 'E' ||
157                    last_non_sp_char == '-';
158            };
159
160            std::string closing;
161            for (size_t i = err_loc.stack.size(); i > 0; i--) {
162                auto & el = err_loc.stack[i - 1];
163                if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
164                    closing += "}";
165                } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
166                    closing += "]";
167                } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) {
168                    throw std::runtime_error("Unexpected stack element type");
169                }
170            }
171
172            // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX
173            static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)");
174
175            auto is_high_surrogate = [&](const std::string & s) {
176                // Check if a partial of a high surrogate (U+D800-U+DBFF)
177                return s.length() >= 4 &&
178                    s[0] == '\\' && s[1] == 'u' &&
179                    std::tolower(s[2]) == 'd' &&
180                    (s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b');
181            };
182
183            // Initialize the unicode marker to a low surrogate to handle the edge case
184            // where a high surrogate (U+D800-U+DBFF) is immediately followed by a
185            // backslash (\)
186            std::string unicode_marker_padding = "udc00";
187            std::smatch last_unicode_seq;
188
189            if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) {
190                std::smatch second_last_seq;
191                std::string prelude = str.substr(0, last_unicode_seq.position());
192
193                // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters
194                unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0');
195
196                if (is_high_surrogate(last_unicode_seq.str())) {
197                    // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF)
198                    unicode_marker_padding += "\\udc00";
199                } else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) {
200                    if (is_high_surrogate(second_last_seq.str())) {
201                        // If this follows a high surrogate, pad it to be a low surrogate
202                        if (last_unicode_seq.length() == 2) {
203                            unicode_marker_padding = "dc00";
204                        } else if (last_unicode_seq.length() == 3) {
205                            unicode_marker_padding = "c00";
206                        } else {
207                            // The original unicode_marker_padding is already padded with 0s
208                        }
209                    }
210                }
211            }
212
213            const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$";
214
215            if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) {
216                // We're inside an object value
217                if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) {
218                    // Was about to create an object value
219                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
220                } else if (can_parse(str + ": 1" + closing)) {
221                    str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing;
222                } else if (last_non_sp_char == '{' && can_parse(str + closing)) {
223                    // Was about to create an object
224                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
225                } else if (can_parse(str + "\"" + closing)) {
226                    // Was inside an object value string
227                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
228                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
229                    // Was inside an object value string after an escape
230                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
231                } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
232                    // Was inside an object value string after a partial unicode escape
233                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
234                } else {
235                    // find last :
236                    auto last_pos = str.find_last_of(':');
237                    if (last_pos == std::string::npos) {
238                        throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location");
239                    }
240                    // Cutting back to opening : for object value
241                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
242                }
243            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) {
244                if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) {
245                    // Was about to create an array value
246                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
247                } else if (can_parse(str + "\"" + closing)) {
248                    // Was inside an array value string
249                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing;
250                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) {
251                    // Was inside an array value string after an escape
252                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing;
253                } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) {
254                    // Was inside an array value string after a partial unicode escape
255                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing;
256                } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) {
257                    // Had just finished a value
258                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing;
259                } else {
260                    auto last_pos = str.find_last_of("[,");
261                    if (last_pos == std::string::npos) {
262                        throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location");
263                    }
264                    // Cutting back to last [ or , for array value
265                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
266                }
267            } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) {
268                if ((last_non_sp_char == '{' && can_parse(str + closing)) ||
269                        (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) {
270                    // Was about to create an object key+value
271                    str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing;
272                } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) {
273                    // Was about to create an object key+value
274                    str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing;
275                } else if (can_parse(str + "\": 1" + closing)) {
276                    // Was inside an object key string
277                    str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing;
278                } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) {
279                    // Was inside an object key string after an escape
280                    str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing;
281                } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) {
282                    // Was inside an object key string after a partial unicode escape
283                    str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing;
284                } else {
285                    auto last_pos = str.find_last_of(':');
286                    if (last_pos == std::string::npos) {
287                        throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
288                    }
289                    // fprintf(stderr, "Cutting back to last : for object key+value\n");
290                    str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing;
291                }
292            } else {
293                throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location");
294            }
295            // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str());
296            out.json = json::parse(str);
297            it = temptative_end;
298            return true;
299        }
300        // handle unclosed top-level primitive
301        if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) {
302            std::string str(it, temptative_end);
303            const auto & magic_seed = out.healing_marker.marker = healing_marker;
304            if (can_parse(str + "\"")) {
305                // Was inside an string
306                str += (out.healing_marker.json_dump_marker = magic_seed) + "\"";
307            } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) {
308                // Was inside an string after an escape
309                str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"";
310            } else {
311                // TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...)
312                // fprintf(stderr, "Closing: TODO\n");
313                return false;
314            }
315            out.json = json::parse(str);
316            it = temptative_end;
317            return true;
318        }
319        return false;
320    }
321    out.json = json::parse(it, end);
322    it = end;
323    return true;
324}