1#include "lexer.h"
  2#include "runtime.h"
  3
  4#include <cctype>
  5#include <functional>
  6#include <map>
  7#include <string>
  8#include <vector>
  9
 10#define FILENAME "jinja-lexer"
 11
 12namespace jinja {
 13
 14static void string_lstrip(std::string & s, const char * chars) {
 15    size_t start = s.find_first_not_of(chars);
 16    if (start == std::string::npos) {
 17        s.clear();
 18    } else {
 19        s.erase(0, start);
 20    }
 21}
 22
 23static void string_rstrip(std::string & s, const char * chars) {
 24    size_t end = s.find_last_not_of(chars);
 25    if (end == std::string::npos) {
 26        s.clear();
 27    } else {
 28        s.erase(end + 1);
 29    }
 30}
 31
 32lexer_result lexer::tokenize(const std::string & source) {
 33    std::vector<token> tokens;
 34
 35    // NOTE: do NOT transform the source string (i.e. preprocessing), as we need to keep
 36    //       the original character positions for error reporting etc.
 37    std::string src = source;
 38
 39    if (source.empty()) {
 40        return {tokens, src};
 41    }
 42
 43    // Normalize \r\n or \r to \n
 44    for (std::string::size_type pos = 0; (pos = src.find("\r\n", pos)) != std::string::npos; ) {
 45        src.erase(pos, 1);
 46        ++pos;
 47    }
 48    for (std::string::size_type pos = 0; (pos = src.find("\r", pos)) != std::string::npos; ) {
 49        src.replace(pos, 1, 1, '\n');
 50        ++pos;
 51    }
 52
 53    // In the default configuration:
 54    //  - a single trailing newline is stripped if present
 55    //  - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
 56    if (source.back() == '\n') {
 57        src.pop_back();
 58    }
 59
 60    size_t pos = 0;
 61    size_t start_pos = 0;
 62    size_t curly_bracket_depth = 0;
 63
 64    using pred = std::function<bool(char)>;
 65    auto consume_while = [&](const pred & predicate) -> std::string {
 66        std::string str;
 67        while (predicate(src[pos])) {
 68            // check for escape char
 69            if (src[pos] == '\\') {
 70                // consume backslash
 71                ++pos;
 72                // check for end of input
 73                if (pos >= src.size()) {
 74                    throw lexer_exception("unexpected end of input after escape character", source, pos);
 75                }
 76                // add escaped char
 77                char escaped_char = src[pos++];
 78                if (escape_chars.find(escaped_char) == escape_chars.end()) {
 79                    throw lexer_exception(std::string("unknown escape character \\") + escaped_char, source, pos);
 80                }
 81                char unescaped_char = escape_chars.at(escaped_char);
 82                str += unescaped_char;
 83                continue;
 84            }
 85
 86            str += src[pos++];
 87            if (pos > src.size()) {
 88                throw lexer_exception("unexpected end of input during consume_while", source, pos);
 89            }
 90        }
 91        return str;
 92    };
 93
 94    auto consume_numeric = [&]() -> std::string {
 95        std::string num = consume_while(is_integer);
 96        if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
 97            ++pos; // Consume '.'
 98            std::string frac = consume_while(is_integer);
 99            num += "." + frac;
100        }
101        return num;
102    };
103
104    auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
105        if (pos + n >= src.size()) return false;
106        for (char c : chars) {
107            if (src[pos + n] == c) return true;
108        }
109        return false;
110    };
111
112    // note: default config for chat template: lstrip_blocks = true, trim_blocks = true
113
114    // text\n[space]{block} --> text\n{block}
115    bool opt_lstrip_blocks = true;
116
117    // {block}\n[space]text --> {block}[space]text
118    bool opt_trim_blocks = true;
119
120    // options set dynamically based on current/last block
121    bool is_lstrip_block = false; // example: {%-
122    bool is_rstrip_block = false; // example: -%}
123
124    while (pos < src.size()) {
125        start_pos = pos;
126        // JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
127
128        // First, consume all text that is outside of a Jinja statement or expression
129        token::type last_token_type = tokens.empty()
130                                            ? token::close_statement // initial state
131                                            : tokens.back().t;
132        if (last_token_type == token::close_statement ||
133            last_token_type == token::close_expression ||
134            last_token_type == token::comment) {
135
136            bool last_block_can_rm_newline = false;
137            is_rstrip_block = false;
138            if (pos > 3) {
139                char c0 = src[pos - 3];
140                char c1 = src[pos - 2];
141                char c2 = src[pos - 1];
142                // strip if: -[%}#]}text
143                is_rstrip_block = c0 == '-'
144                                    && (c1 == '%' || c1 == '}' || c1 == '#')
145                                    && c2 == '}';
146                // match behavior of hf.js: exclude {{ and }} cases, regex: ([#%-]})
147                last_block_can_rm_newline = (c1 == '#' || c1 == '%' || c1 == '-') && c2 == '}';
148            }
149
150            size_t start = pos;
151            size_t end = start;
152            while (pos < src.size() &&
153                    // Keep going until we hit the next Jinja statement or expression
154                    !(
155                        src[pos] == '{' &&
156                        next_pos_is( {'%', '{', '#'} )
157                    )) {
158                end = ++pos;
159            }
160
161            // equivalent to hf.js code: template.replace(/^[ \t]*({[#%-])/gm, "$1");
162            if (opt_lstrip_blocks && src[pos] == '{' && next_pos_is({'%', '#', '-'})) {
163                size_t current = end;
164                while (current > start) {
165                    char c = src[current - 1];
166                    if (current == 1) {
167                        end = 0; // Trim from the start of the string
168                        break;
169                    }
170                    if (c == '\n') {
171                        end = current; // Trim from the start of the line
172                        break;
173                    }
174                    if (!std::isspace(static_cast<unsigned char>(c))) {
175                        break; // Found non-whitespace before newline, keep
176                    }
177                    --current;
178                }
179            }
180
181            std::string text = src.substr(start, end - start);
182
183            // equivalent to hf.js code: template.replace(/([#%-]})\n/g, "$1");
184            if (opt_trim_blocks && last_block_can_rm_newline) {
185                if (!text.empty() && text.front() == '\n') {
186                    text.erase(text.begin());
187                }
188            }
189
190            if (is_rstrip_block) {
191                // example: {last_block}[space]text
192                // doing lstrip on text, effectively rstrip the LAST block
193                // JJ_DEBUG("RSTRIP block detected, current text: '%s'", text.c_str());
194                string_lstrip(text, " \t\r\n");
195            }
196
197            is_lstrip_block = src[pos] == '{' && next_pos_is({'{', '%', '#'}) && next_pos_is({'-'}, 2);
198            if (is_lstrip_block) {
199                // example: text[space]{current_block}
200                // doing rstrip on text, effectively lstrip the CURRENT block
201                // JJ_DEBUG("LSTRIP block detected, current text: '%s'", text.c_str());
202                string_rstrip(text, " \t\r\n");
203            }
204
205            if (!text.empty()) {
206                // JJ_DEBUG("consumed text: '%s'", text.c_str());
207                tokens.push_back({token::text, text, start_pos});
208                continue;
209            }
210        }
211
212        // Possibly consume a comment
213        // TODO: handle lstrip/rstrip for comments? (not important for now)
214        if (src[pos] == '{' && next_pos_is( {'#'} )) {
215            start_pos = pos;
216            pos += 2; // Skip the opening {#
217            std::string comment;
218            while (!(src[pos] == '#' && next_pos_is( {'}'} ))) {
219                if (pos + 2 >= src.size()) {
220                    throw lexer_exception("missing end of comment tag", source, pos);
221                }
222                comment += src[pos++];
223            }
224            JJ_DEBUG("consumed comment: '%s'", comment.c_str());
225            tokens.push_back({token::comment, comment, start_pos});
226            pos += 2; // Skip the closing #}
227            continue;
228        }
229
230        if (src[pos] == '-' && (
231                last_token_type == token::open_expression ||
232                last_token_type == token::open_statement)
233        ) {
234            JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
235            pos++; // consume '-' in {%- or {{-
236            if (pos >= src.size()) break;
237        }
238
239        // Consume (and ignore) all whitespace inside Jinja statements or expressions
240        consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)); });
241
242        if (pos >= src.size()) break;
243
244        char ch = src[pos];
245
246        bool is_closing_block = ch == '-' && next_pos_is( {'%', '}'} );
247
248        // Check for unary operators
249        if (!is_closing_block && (ch == '-' || ch == '+')) {
250            start_pos = pos;
251            token::type last_token_type = tokens.empty() ? token::eof : tokens.back().t;
252            if (last_token_type == token::text || last_token_type == token::eof) {
253                throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
254            }
255            switch (last_token_type) {
256                case token::identifier:
257                case token::numeric_literal:
258                case token::string_literal:
259                case token::close_paren:
260                case token::close_square_bracket:
261                    // Part of a binary operator
262                    // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
263                    // Continue parsing normally
264                    break;
265                default: {
266                    // Is part of a unary operator
267                    // (-1), [-1], (1 + -1), not -1, -apple
268                    ++pos; // Consume the operator
269
270                    // Check for numbers following the unary operator
271                    std::string num = consume_numeric();
272                    std::string value = std::string(1, ch) + num;
273                    token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
274                    // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
275                    tokens.push_back({t, value, start_pos});
276                    continue;
277                }
278            }
279        }
280
281        // Try to match one of the tokens in the mapping table
282        bool matched = false;
283        for (const auto & [seq, typ] : ordered_mapping_table) {
284            start_pos = pos;
285            // Inside an object literal, don't treat "}}" as expression-end
286            if (seq == "}}" && curly_bracket_depth > 0) {
287                continue;
288            }
289            if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) {
290                tokens.push_back({typ, seq, start_pos});
291                if (typ == token::open_expression) {
292                    curly_bracket_depth = 0;
293                } else if (typ == token::open_curly_bracket) {
294                    ++curly_bracket_depth;
295                } else if (typ == token::close_curly_bracket) {
296                    --curly_bracket_depth;
297                }
298
299                pos += seq.size();
300                matched = true;
301                break; // continue main loop
302            }
303        }
304        if (matched) continue; // continue main loop
305
306        // Strings
307        if (ch == '\'' || ch == '"') {
308            start_pos = pos;
309            ++pos; // Skip opening quote
310            std::string str = consume_while([ch](char c) { return c != ch; });
311            // JJ_DEBUG("consumed string literal: '%s'", str.c_str());
312            tokens.push_back({token::string_literal, str, start_pos});
313            ++pos; // Skip closing quote
314            continue;
315        }
316
317        // Numbers
318        if (is_integer(ch)) {
319            start_pos = pos;
320            std::string num = consume_numeric();
321            // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
322            tokens.push_back({token::numeric_literal, num, start_pos});
323            continue;
324        }
325
326        // Identifiers
327        if (is_word(ch)) {
328            start_pos = pos;
329            std::string word = consume_while(is_word);
330            // JJ_DEBUG("consumed identifier: '%s'", word.c_str());
331            tokens.push_back({token::identifier, word, start_pos});
332            continue;
333        }
334
335        throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
336    }
337
338    return {std::move(tokens), src};
339}
340
341} // namespace jinja