// Tests common_regex, especially its support for partial matches at the end of the input.
  2
#include "common.h"
#include "regex-partial.h"

#include <iostream>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
  9
 10template <class T> static void assert_equals(const T & expected, const T & actual) {
 11    if (expected != actual) {
 12        std::cerr << "Expected: " << expected << std::endl;
 13        std::cerr << "  Actual: " << actual << std::endl;
 14        std::cerr << std::flush;
 15        throw std::runtime_error("Test failed");
 16    }
 17}
 18
 19struct test_case {
 20    std::string pattern;
 21    struct input_output {
 22        std::string input;
 23        common_regex_match output;
 24    };
 25    std::vector<input_output> inputs_outputs;
 26};
 27
 28static std::string common_regex_match_type_name(common_regex_match_type type) {
 29    switch (type) {
 30        case COMMON_REGEX_MATCH_TYPE_NONE:
 31            return "COMMON_REGEX_MATCH_TYPE_NONE";
 32        case COMMON_REGEX_MATCH_TYPE_PARTIAL:
 33            return "COMMON_REGEX_MATCH_TYPE_PARTIAL";
 34        case COMMON_REGEX_MATCH_TYPE_FULL:
 35            return "COMMON_REGEX_MATCH_TYPE_FULL";
 36    }
 37    return "?";
 38}
 39
 40static void test_regex() {
 41    printf("[%s]\n", __func__);
 42    auto test = [](const test_case & test_case) {
 43        common_regex cr(test_case.pattern);
 44        std::cout << "Testing pattern: /" << test_case.pattern << "/\n";
 45        // std::cout << "    partial rev: " << cr.reversed_partial_pattern.str() << '\n';
 46        for (const auto & input_output : test_case.inputs_outputs) {
 47            std::cout << "  Input: " << input_output.input << '\n';
 48            auto m = cr.search(input_output.input, 0);
 49            if (m != input_output.output) {
 50                auto match_to_str = [&](const std::optional<common_regex_match> & m) {
 51                    std::ostringstream ss;
 52                    if (m->type == COMMON_REGEX_MATCH_TYPE_NONE) {
 53                        ss << "<no match>";
 54                    } else {
 55                        GGML_ASSERT(!input_output.output.groups.empty());
 56                        std::vector<std::string> parts;
 57                        for (const auto & g : m->groups) {
 58                            parts.push_back("{" + std::to_string(g.begin) + ", " + std::to_string(g.end) + "}");
 59                        }
 60                        ss << "{" << common_regex_match_type_name(m->type) << ", {" << string_join(parts, ", ") << "}}";
 61                    }
 62                    return ss.str();
 63                };
 64                std::cout << "    Expected: " << match_to_str(input_output.output) << '\n';
 65                std::cout << "         Got: " << match_to_str(m) << '\n';
 66                std::cout << " Inverted pattern: /" << regex_to_reversed_partial_regex(test_case.pattern) << "/\n";
 67
 68                throw std::runtime_error("Test failed");
 69            }
 70        }
 71    };
 72    test({
 73        "a",
 74        {
 75            {"a", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
 76            {"b", {COMMON_REGEX_MATCH_TYPE_NONE, {}}},
 77            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 1}}}},
 78            {"ba", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 2}}}},
 79        }
 80    });
 81    test({
 82        "abcd",
 83        {
 84            {"abcd", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
 85            {"abcde", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
 86            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
 87            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
 88            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
 89            {"d", {}},
 90            {"bcd", {}},
 91            {"cde", {}},
 92            {"cd", {}},
 93            {"yeah ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{5, 7}}}},
 94            {"abbie", {}},
 95            {"", {}},
 96        }
 97    });
 98    test({
 99        ".*?ab",
100        {
101            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
102            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
103            {"dab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
104            {"dabc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
105            {"da", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
106            {"d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
107        }
108    });
109    test({
110        "a.*?b",
111        {
112            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
113            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
114            {"a b", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
115            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
116            {"argh", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
117            {"d", {}},
118            {"b", {}},
119        }
120    });
121    test({
122        "ab(?:cd){2,4}ef",
123        {
124            // {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, 0, {}}},
125            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
126            {"abcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
127            {"abcde", {}},
128            {"abcdef", {}},
129            {"abcdcd", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
130            {"abcdcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 7}}}},
131            {"abcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
132            {"abcdcdcdcdef", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 12}}}},
133            {"abcdcdcdcdcdef", {}},
134            {"abcde", {}},
135            {"yea", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{2, 3}}}},
136        }
137    });
138    test({
139        "a(?:rte| pure )fact",
140        {
141            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
142            {"art", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
143            {"artefa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
144            {"fact", {}},
145            {"an arte", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{3, 7}}}},
146            {"artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}}}},
147            {"an artefact", {COMMON_REGEX_MATCH_TYPE_FULL, {{3, 11}}}},
148            {"a pure", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
149            {"a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 11}}}},
150            {"it's a pure fact", {COMMON_REGEX_MATCH_TYPE_FULL, {{5, 16}}}},
151            {"" , {}},
152            {"pure", {}},
153            {"pure fact", {}},
154        }
155    });
156    test({
157        "abc",
158        {
159            {" abcc", {COMMON_REGEX_MATCH_TYPE_FULL, {{1, 4}}}},
160            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
161            {"abc", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
162            {" ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{1, 3}}}},
163            {"a", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 1}}}},
164            {"b", {}},
165            {"c", {}},
166            {"", {}},
167        }
168    });
169
170    test({
171        "(?:abc)?\\s*def",
172        {
173            {"ab", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
174            {"abc", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
175            {"abc ", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 4}}}},
176            {"abc d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
177            {"abc de", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
178            {"abc def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
179            {"abc defg", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
180            {"abc defgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 7}}}},
181            {"abcde", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 5}}}},
182            {"abcdefgh", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 6}}}},
183            {" d", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 2}}}},
184            {"def", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 3}}}},
185        }
186    });
187
188    test({
189        "a+b",
190        {
191            {"aaab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 4}}}},
192            {"aaa", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 3}}}},
193            {"ab", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 2}}}},
194        }
195    });
196
197    test({
198        "(?:"
199            "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
200            "("                          // match 2 (open_tag)
201                "<tool_call>"
202                "|<function_call>"
203                "|<tool>"
204                "|<tools>"
205                "|<response>"
206                "|<json>"
207                "|<xml>"
208                "|<JSON>"
209            ")?"
210            "(\\s*\\{\\s*\"name\"\\s*:)" // match 3 (named tool call)
211        ")"
212        "|<function=([^>]+)>"            // match 4 (function name)
213        "|<function name=\"([^\"]+)\">", // match 5 (function name again)
214        {
215            {"{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 8}, {54, 54}, {54, 54}, {0, 8}, {54, 54}, {54, 54}}}},
216            {"<tool_call> {\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 18}}}},
217            {"<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 17}}}},
218            {"Let's call something\n<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{21, 38}}}},
219            {"Ok then<tool_call>{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 24}}}},
220            {"{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{0, 6}}}},
221            {"Ok then{\"name", {COMMON_REGEX_MATCH_TYPE_PARTIAL, {{7, 13}}}},
222            {"<tool_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 20}, {66, 66}, {0, 11}, {11, 20}, {66, 66}, {66, 66}}}},
223            {"<function_call> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 24}, {70, 70}, {0, 15}, {15, 24}, {70, 70}, {70, 70}}}},
224            {"<function name=\"special_function\"> {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 34}, {89, 89}, {89, 89}, {89, 89}, {89, 89}, {16, 32}}}},
225            {"<function=all>", {COMMON_REGEX_MATCH_TYPE_FULL, {{0, 14}, {14, 14}, {14, 14}, {14, 14}, {10, 13}, {14, 14}}}},
226
227        }
228    });
229}
230
231static void test_regex_to_reversed_partial_regex() {
232    printf("[%s]\n", __func__);
233
234    assert_equals<std::string>(
235        "^((?:(?:c)?b)?a)",
236        regex_to_reversed_partial_regex("abc"));
237
238    assert_equals<std::string>(
239        "^(a+)",
240        regex_to_reversed_partial_regex("a+"));
241
242    assert_equals<std::string>(
243        "^(a*)",
244        regex_to_reversed_partial_regex("a*"));
245
246    assert_equals<std::string>(
247        "^(a?)",
248        regex_to_reversed_partial_regex("a?"));
249
250    assert_equals<std::string>(
251        "^([a-z])",
252        regex_to_reversed_partial_regex("[a-z]"));
253
254    assert_equals<std::string>(
255        "^((?:\\w+)?[a-z])",
256        regex_to_reversed_partial_regex("[a-z]\\w+"));
257
258    assert_equals<std::string>(
259        "^((?:a|b))",
260        regex_to_reversed_partial_regex("(?:a|b)"));
261    assert_equals<std::string>(
262        "^((?:(?:(?:d)?c)?b)?a)",
263        regex_to_reversed_partial_regex("abcd"));
264    assert_equals<std::string>(
265        "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ??
266        regex_to_reversed_partial_regex("a*b"));
267    assert_equals<std::string>(
268        "^((?:(?:b)?a)?.*)",
269        regex_to_reversed_partial_regex(".*?ab"));
270    assert_equals<std::string>(
271        "^((?:(?:b)?.*)?a)",
272        regex_to_reversed_partial_regex("a.*?b"));
273    assert_equals<std::string>(
274        "^((?:(?:d)?(?:(?:c)?b))?a)",
275        regex_to_reversed_partial_regex("a(bc)d"));
276    assert_equals<std::string>(
277        "^((?:(?:(?:c)?b|(?:e)?d))?a)",
278        regex_to_reversed_partial_regex("a(bc|de)"));
279    assert_equals<std::string>(
280        "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)",
281        regex_to_reversed_partial_regex("ab{2,4}c"));
282}
283
284int main() {
285    test_regex_to_reversed_partial_regex();
286    test_regex();
287    std::cout << "All tests passed.\n";
288}