1#include "regex-partial.h"
  2#include "common.h"
  3#include <functional>
  4#include <optional>
  5
  6common_regex::common_regex(const std::string & pattern) :
  7    pattern(pattern),
  8    rx(pattern),
  9    rx_reversed_partial(regex_to_reversed_partial_regex(pattern)) {}
 10
 11common_regex_match common_regex::search(const std::string & input, size_t pos, bool as_match) const {
 12    std::smatch match;
 13    if (pos > input.size()) {
 14        throw std::runtime_error("Position out of bounds");
 15    }
 16    auto start = input.begin() + pos;
 17    auto found = as_match
 18        ? std::regex_match(start, input.end(), match, rx)
 19        : std::regex_search(start, input.end(), match, rx);
 20    if (found) {
 21        common_regex_match res;
 22        res.type = COMMON_REGEX_MATCH_TYPE_FULL;
 23        for (size_t i = 0; i < match.size(); ++i) {
 24            auto begin = pos + match.position(i);
 25            res.groups.emplace_back(begin, begin + match.length(i));
 26        }
 27        return res;
 28    }
 29    std::match_results<std::string::const_reverse_iterator> srmatch;
 30    if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
 31        auto group = srmatch[1].str();
 32        if (group.length() != 0) {
 33            auto it = srmatch[1].second.base();
 34            // auto position = static_cast<size_t>(std::distance(input.begin(), it));
 35            if ((!as_match) || it == input.begin()) {
 36                common_regex_match res;
 37                res.type = COMMON_REGEX_MATCH_TYPE_PARTIAL;
 38                const size_t begin = std::distance(input.begin(), it);
 39                const size_t end = input.size();
 40                if (begin == std::string::npos || end == std::string::npos || begin > end) {
 41                    throw std::runtime_error("Invalid range");
 42                }
 43                res.groups.push_back({begin, end});
 44                return res;
 45            }
 46        }
 47    }
 48    return {};
 49}
 50
/*
  Transforms a regex pattern to a partial match pattern that operates on a reversed input string to find partial final matches of the original pattern.

  Ideally we'd like to use boost::match_partial (https://beta.boost.org/doc/libs/1_59_0/libs/regex/doc/html/boost_regex/partial_matches.html)
  to see if a string ends with a partial regex match, but it's not in std::regex yet.
  Instead, we'll transform the regex into a partial match regex that is matched, anchored at its start, against the reverse iterators of the input.

  - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
  - /a|b/ -> ^(a|b)
  - /a*?/ -> ^(a*) (the '?' of a lazy '*?' is dropped; note this can match "")
  - /a*b/ -> ^((?:b)?a*)
  - /.*?ab/ -> ^((?:(?:b)?a)?.*)
  - /a.*?b/ -> ^((?:(?:b)?.*)?a)
  - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
  - /a(bc|de)/ -> ^((?:(?:(?:c)?b|(?:e)?d))?a)
  - /ab{2,4}c/ -> ^cb?b?bba -> ^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)

  The regex is matched at the start of the reversed string (it does not need to consume all of it), and the end of the first (and only) capturing group indicates the reversed start of the original partial pattern.
  All other groups are turned into non-capturing groups. Only the lazy '*?' is made greedy (its '?' is dropped); other lazy forms such as '+?' and '??' are kept as written.
*/
 71std::string regex_to_reversed_partial_regex(const std::string & pattern) {
 72    auto it = pattern.begin();
 73    const auto end = pattern.end();
 74
 75    std::function<std::string()> process = [&]() {
 76        std::vector<std::vector<std::string>> alternatives(1);
 77        std::vector<std::string> * sequence = &alternatives.back();
 78
 79        while (it != end) {
 80            if (*it == '[') {
 81                auto start = it;
 82                ++it;
 83                while (it != end) {
 84                    if ((*it == '\\') && (++it != end)) {
 85                        ++it;
 86                    } else if ((it != end) && (*it == ']')) {
 87                        break;
 88                    } else {
 89                        ++it;
 90                    }
 91                }
 92                if (it == end) {
 93                    throw std::runtime_error("Unmatched '[' in pattern");
 94                }
 95                ++it;
 96                sequence->push_back(std::string(start, it));
 97            } else if (*it == '*' || *it == '?' || *it == '+') {
 98                if (sequence->empty()) {
 99                    throw std::runtime_error("Quantifier without preceding element");
100                }
101                sequence->back() += *it;
102                auto is_star = *it == '*';
103                ++it;
104                if (is_star) {
105                    if (*it == '?') {
106                        ++it;
107                    }
108                }
109            } else if (*it == '{') {
110                if (sequence->empty()) {
111                    throw std::runtime_error("Repetition without preceding element");
112                }
113                ++it;
114                auto start = it;
115                while (it != end && *it != '}') {
116                    ++it;
117                }
118                if (it == end) {
119                    throw std::runtime_error("Unmatched '{' in pattern");
120                }
121                auto parts = string_split(std::string(start, it), ",");
122                ++it;
123                if (parts.size() > 2) {
124                    throw std::runtime_error("Invalid repetition range in pattern");
125                }
126
127                auto parseOptInt = [&](const std::string & s, const std::optional<int> & def = std::nullopt) -> std::optional<int> {
128                    if (s.empty()) {
129                        return def;
130                    }
131                    return std::stoi(s);
132                };
133                auto min = parseOptInt(parts[0], 0);
134                auto max = parts.size() == 1 ? min : parseOptInt(parts[1]);
135                if (min && max && *max < *min) {
136                    throw std::runtime_error("Invalid repetition range in pattern");
137                }
138                // Brutal but... let's repeat at least min times, then ? for the delta between min & max (or * for unbounded)
139                auto part = sequence->back();
140                sequence->pop_back();
141                for (int i = 0; i < *min; i++) {
142                    sequence->push_back(part);
143                }
144                if (max) {
145                    for (int i = *min; i < *max; i++) {
146                        sequence->push_back(part + "?");
147                    }
148                } else {
149                    sequence->push_back(part + "*");
150                }
151            } else if (*it == '(') {
152                ++it;
153                if (it != end && *it == '?' && (it + 1 != end) && *(it + 1) == ':') {
154                    it += 2;
155                }
156                auto sub = process();
157                if (*it != ')') {
158                    throw std::runtime_error("Unmatched '(' in pattern");
159                }
160                ++it;
161                auto & part = sequence->emplace_back("(?:");
162                part += sub;
163                part += ")";
164            } else if (*it == ')') {
165                break;
166            } else if (*it == '|') {
167                ++it;
168                alternatives.emplace_back();
169                sequence = &alternatives.back();
170            } else if (*it == '\\' && (++it != end)) {
171                auto str = std::string("\\") + *it;
172                sequence->push_back(str);
173                ++it;
174            } else if (it != end) {
175                sequence->push_back(std::string(1, *it));
176                ++it;
177            }
178        }
179
180        // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
181        // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
182        // We'll do the outermost capturing group and final .* in the enclosing function.
183        std::vector<std::string> res_alts;
184        for (const auto & parts : alternatives) {
185            auto & res = res_alts.emplace_back();
186            for (size_t i = 0; i < parts.size() - 1; i++) {
187                res += "(?:";
188            }
189            for (auto it = parts.rbegin(); it != parts.rend(); ++it) {
190                res += *it;
191                if (it != parts.rend() - 1) {
192                    res += ")?";
193                }
194            }
195        }
196        return string_join(res_alts, "|");
197    };
198    auto res = process();
199    if (it != end) {
200        throw std::runtime_error("Unmatched '(' in pattern");
201    }
202
203    return "^(" + res + ")";
204}