1#include "lexer.h"
  2#include "runtime.h"
  3#include "parser.h"
  4
  5#include <algorithm>
  6#include <memory>
  7#include <stdexcept>
  8#include <string>
  9#include <vector>
 10
 11#define FILENAME "jinja-parser"
 12
 13namespace jinja {
 14
 15// Helper to check type without asserting (useful for logic)
 16template<typename T>
 17static bool is_type(const statement_ptr & ptr) {
 18    return dynamic_cast<const T*>(ptr.get()) != nullptr;
 19}
 20
 21class parser {
 22    const std::vector<token> & tokens;
 23    size_t current = 0;
 24
 25    std::string source; // for error reporting
 26
 27public:
 28    parser(const std::vector<token> & t, const std::string & src) : tokens(t), source(src) {}
 29
 30    program parse() {
 31        statements body;
 32        while (current < tokens.size()) {
 33            body.push_back(parse_any());
 34        }
 35        return program(std::move(body));
 36    }
 37
 38    // NOTE: start_pos is the token index, used for error reporting
 39    template<typename T, typename... Args>
 40    std::unique_ptr<T> mk_stmt(size_t start_pos, Args&&... args) {
 41        auto ptr = std::make_unique<T>(std::forward<Args>(args)...);
 42        assert(start_pos < tokens.size());
 43        ptr->pos = tokens[start_pos].pos;
 44        return ptr;
 45    }
 46
 47private:
 48    const token & peek(size_t offset = 0) const {
 49        if (current + offset >= tokens.size()) {
 50            static const token end_token{token::eof, "", 0};
 51            return end_token;
 52        }
 53        return tokens[current + offset];
 54    }
 55
 56    token expect(token::type type, const std::string&  error) {
 57        const auto & t = peek();
 58        if (t.t != type) {
 59            throw parser_exception("Parser Error: " + error + " (Got " + t.value + ")", source, t.pos);
 60        }
 61        current++;
 62        return t;
 63    }
 64
 65    void expect_identifier(const std::string & name) {
 66        const auto & t = peek();
 67        if (t.t != token::identifier || t.value != name) {
 68            throw parser_exception("Expected identifier: " + name, source, t.pos);
 69        }
 70        current++;
 71    }
 72
 73    bool is(token::type type) const {
 74        return peek().t == type;
 75    }
 76
 77    bool is_identifier(const std::string & name) const {
 78        return peek().t == token::identifier && peek().value == name;
 79    }
 80
 81    bool is_statement(const std::vector<std::string> & names) const {
 82        if (peek(0).t != token::open_statement || peek(1).t != token::identifier) {
 83            return false;
 84        }
 85        std::string val = peek(1).value;
 86        return std::find(names.begin(), names.end(), val) != names.end();
 87    }
 88
 89    statement_ptr parse_any() {
 90        size_t start_pos = current;
 91        switch (peek().t) {
 92            case token::comment:
 93                return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
 94            case token::text:
 95                return mk_stmt<string_literal>(start_pos, tokens[current++].value);
 96            case token::open_statement:
 97                return parse_jinja_statement();
 98            case token::open_expression:
 99                return parse_jinja_expression();
100            default:
101                throw std::runtime_error("Unexpected token type");
102        }
103    }
104
105    statement_ptr parse_jinja_expression() {
106        // Consume {{ }} tokens
107        expect(token::open_expression, "Expected {{");
108        auto result = parse_expression();
109        expect(token::close_expression, "Expected }}");
110        return result;
111    }
112
113    statement_ptr parse_jinja_statement() {
114        // Consume {% token
115        expect(token::open_statement, "Expected {%");
116
117        if (peek().t != token::identifier) {
118            throw std::runtime_error("Unknown statement");
119        }
120
121        size_t start_pos = current;
122        std::string name = peek().value;
123        current++; // consume identifier
124
125        statement_ptr result;
126        if (name == "set") {
127            result = parse_set_statement(start_pos);
128
129        } else if (name == "if") {
130            result = parse_if_statement(start_pos);
131            // expect {% endif %}
132            expect(token::open_statement, "Expected {%");
133            expect_identifier("endif");
134            expect(token::close_statement, "Expected %}");
135
136        } else if (name == "macro") {
137            result = parse_macro_statement(start_pos);
138            // expect {% endmacro %}
139            expect(token::open_statement, "Expected {%");
140            expect_identifier("endmacro");
141            expect(token::close_statement, "Expected %}");
142
143        } else if (name == "for") {
144            result = parse_for_statement(start_pos);
145            // expect {% endfor %}
146            expect(token::open_statement, "Expected {%");
147            expect_identifier("endfor");
148            expect(token::close_statement, "Expected %}");
149
150        } else if (name == "break") {
151            expect(token::close_statement, "Expected %}");
152            result = mk_stmt<break_statement>(start_pos);
153
154        } else if (name == "continue") {
155            expect(token::close_statement, "Expected %}");
156            result = mk_stmt<continue_statement>(start_pos);
157
158        } else if (name == "call") {
159            statements caller_args;
160            // bool has_caller_args = false;
161            if (is(token::open_paren)) {
162                // Optional caller arguments, e.g. {% call(user) dump_users(...) %}
163                caller_args = parse_args();
164                // has_caller_args = true;
165            }
166            auto callee = parse_primary_expression();
167            if (!is_type<identifier>(callee)) throw std::runtime_error("Expected identifier");
168
169            auto call_args = parse_args();
170            expect(token::close_statement, "Expected %}");
171
172            statements body;
173            while (!is_statement({"endcall"})) {
174                body.push_back(parse_any());
175            }
176
177            expect(token::open_statement, "Expected {%");
178            expect_identifier("endcall");
179            expect(token::close_statement, "Expected %}");
180
181            auto call_expr = mk_stmt<call_expression>(start_pos, std::move(callee), std::move(call_args));
182            result = mk_stmt<call_statement>(start_pos, std::move(call_expr), std::move(caller_args), std::move(body));
183
184        } else if (name == "filter") {
185            auto filter_node = parse_primary_expression();
186            if (is_type<identifier>(filter_node) && is(token::open_paren)) {
187                filter_node = parse_call_expression(std::move(filter_node));
188            }
189            expect(token::close_statement, "Expected %}");
190
191            statements body;
192            while (!is_statement({"endfilter"})) {
193                body.push_back(parse_any());
194            }
195
196            expect(token::open_statement, "Expected {%");
197            expect_identifier("endfilter");
198            expect(token::close_statement, "Expected %}");
199            result = mk_stmt<filter_statement>(start_pos, std::move(filter_node), std::move(body));
200
201        } else if (name == "generation" || name == "endgeneration") {
202            // Ignore generation blocks (transformers-specific)
203            // See https://github.com/huggingface/transformers/pull/30650 for more information.
204            result = mk_stmt<noop_statement>(start_pos);
205            current++;
206
207        } else {
208            throw std::runtime_error("Unknown statement: " + name);
209        }
210        return result;
211    }
212
213    statement_ptr parse_set_statement(size_t start_pos) {
214        // NOTE: `set` acts as both declaration statement and assignment expression
215        auto left = parse_expression_sequence();
216        statement_ptr value = nullptr;
217        statements body;
218
219        if (is(token::equals)) {
220            current++;
221            value = parse_expression_sequence();
222        } else {
223            // parsing multiline set here
224            expect(token::close_statement, "Expected %}");
225            while (!is_statement({"endset"})) {
226                body.push_back(parse_any());
227            }
228            expect(token::open_statement, "Expected {%");
229            expect_identifier("endset");
230        }
231        expect(token::close_statement, "Expected %}");
232        return mk_stmt<set_statement>(start_pos, std::move(left), std::move(value), std::move(body));
233    }
234
235    statement_ptr parse_if_statement(size_t start_pos) {
236        auto test = parse_expression();
237        expect(token::close_statement, "Expected %}");
238
239        statements body;
240        statements alternate;
241
242        // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %}
243        while (!is_statement({"elif", "else", "endif"})) {
244            body.push_back(parse_any());
245        }
246
247        if (is_statement({"elif"})) {
248            size_t pos0 = current;
249            ++current; // consume {%
250            ++current; // consume 'elif'
251            alternate.push_back(parse_if_statement(pos0)); // nested If
252        } else if (is_statement({"else"})) {
253            ++current; // consume {%
254            ++current; // consume 'else'
255            expect(token::close_statement, "Expected %}");
256
257            // keep going until we hit {% endif %}
258            while (!is_statement({"endif"})) {
259                alternate.push_back(parse_any());
260            }
261        }
262        return mk_stmt<if_statement>(start_pos, std::move(test), std::move(body), std::move(alternate));
263    }
264
265    statement_ptr parse_macro_statement(size_t start_pos) {
266        auto name = parse_primary_expression();
267        auto args = parse_args();
268        expect(token::close_statement, "Expected %}");
269        statements body;
270        // Keep going until we hit {% endmacro
271        while (!is_statement({"endmacro"})) {
272            body.push_back(parse_any());
273        }
274        return mk_stmt<macro_statement>(start_pos, std::move(name), std::move(args), std::move(body));
275    }
276
277    statement_ptr parse_expression_sequence(bool primary = false) {
278        size_t start_pos = current;
279        statements exprs;
280        exprs.push_back(primary ? parse_primary_expression() : parse_expression());
281        bool is_tuple = is(token::comma);
282        while (is(token::comma)) {
283            current++; // consume comma
284            exprs.push_back(primary ? parse_primary_expression() : parse_expression());
285        }
286        return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
287    }
288
289    statement_ptr parse_for_statement(size_t start_pos) {
290        // e.g., `message` in `for message in messages`
291        auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
292        if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
293        current++;
294
295        // `messages` in `for message in messages`
296        auto iterable = parse_expression();
297        expect(token::close_statement, "Expected %}");
298
299        statements body;
300        statements alternate;
301
302        // Keep going until we hit {% endfor or {% else
303        while (!is_statement({"endfor", "else"})) {
304            body.push_back(parse_any());
305        }
306
307        if (is_statement({"else"})) {
308            current += 2;
309            expect(token::close_statement, "Expected %}");
310            while (!is_statement({"endfor"})) {
311                alternate.push_back(parse_any());
312            }
313        }
314        return mk_stmt<for_statement>(
315            start_pos,
316            std::move(loop_var), std::move(iterable),
317            std::move(body), std::move(alternate));
318    }
319
320    statement_ptr parse_expression() {
321        // Choose parse function with lowest precedence
322        return parse_if_expression();
323    }
324
325    statement_ptr parse_if_expression() {
326        auto a = parse_logical_or_expression();
327        if (is_identifier("if")) {
328            // Ternary expression
329            size_t start_pos = current;
330            ++current; // consume 'if'
331            auto test = parse_logical_or_expression();
332            if (is_identifier("else")) {
333                // Ternary expression with else
334                size_t pos0 = current;
335                ++current; // consume 'else'
336                auto false_expr = parse_if_expression(); // recurse to support chained ternaries
337                return mk_stmt<ternary_expression>(pos0, std::move(test), std::move(a), std::move(false_expr));
338            } else {
339                // Select expression on iterable
340                return mk_stmt<select_expression>(start_pos, std::move(a), std::move(test));
341            }
342        }
343        return a;
344    }
345
346    statement_ptr parse_logical_or_expression() {
347        auto left = parse_logical_and_expression();
348        while (is_identifier("or")) {
349            size_t start_pos = current;
350            token op = tokens[current++];
351            left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
352        }
353        return left;
354    }
355
356    statement_ptr parse_logical_and_expression() {
357        auto left = parse_logical_negation_expression();
358        while (is_identifier("and")) {
359            size_t start_pos = current;
360            auto op = tokens[current++];
361            left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
362        }
363        return left;
364    }
365
366    statement_ptr parse_logical_negation_expression() {
367        // Try parse unary operators
368        if (is_identifier("not")) {
369            size_t start_pos = current;
370            auto op = tokens[current++];
371            return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
372        }
373        return parse_comparison_expression();
374    }
375
376    statement_ptr parse_comparison_expression() {
377        // NOTE: membership has same precedence as comparison
378        // e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana')))
379        auto left = parse_additive_expression();
380        while (true) {
381            token op;
382            size_t start_pos = current;
383            if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
384                op = {token::identifier, "not in", tokens[current].pos};
385                current += 2;
386            } else if (is_identifier("in")) {
387                op = tokens[current++];
388            } else if (is(token::comparison_binary_operator)) {
389                op = tokens[current++];
390            } else break;
391            left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
392        }
393        return left;
394    }
395
396    statement_ptr parse_additive_expression() {
397        auto left = parse_multiplicative_expression();
398        while (is(token::additive_binary_operator)) {
399            size_t start_pos = current;
400            auto op = tokens[current++];
401            left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
402        }
403        return left;
404    }
405
406    statement_ptr parse_multiplicative_expression() {
407        auto left = parse_test_expression();
408        while (is(token::multiplicative_binary_operator)) {
409            size_t start_pos = current;
410            auto op = tokens[current++];
411            left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
412        }
413        return left;
414    }
415
416    statement_ptr parse_test_expression() {
417        auto operand = parse_filter_expression();
418        while (is_identifier("is")) {
419            size_t start_pos = current;
420            current++;
421            bool negate = false;
422            if (is_identifier("not")) { current++; negate = true; }
423            auto test_id = parse_primary_expression();
424            // FIXME: tests can also be expressed like this: if x is eq 3
425            if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
426            operand = mk_stmt<test_expression>(start_pos, std::move(operand), negate, std::move(test_id));
427        }
428        return operand;
429    }
430
431    statement_ptr parse_filter_expression() {
432        auto operand = parse_call_member_expression();
433        while (is(token::pipe)) {
434            size_t start_pos = current;
435            current++;
436            auto filter = parse_primary_expression();
437            if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
438            operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
439        }
440        return operand;
441    }
442
443    statement_ptr parse_call_member_expression() {
444        // Handle member expressions recursively
445        auto member = parse_member_expression(parse_primary_expression());
446        return is(token::open_paren)
447            ? parse_call_expression(std::move(member)) // foo.x()
448            : std::move(member);
449    }
450
451    statement_ptr parse_call_expression(statement_ptr callee) {
452        size_t start_pos = current;
453        auto expr = mk_stmt<call_expression>(start_pos, std::move(callee), parse_args());
454        auto member = parse_member_expression(std::move(expr)); // foo.x().y
455        return is(token::open_paren)
456            ? parse_call_expression(std::move(member)) // foo.x()()
457            : std::move(member);
458    }
459
460    statements parse_args() {
461        // comma-separated arguments list
462        expect(token::open_paren, "Expected (");
463        statements args;
464        while (!is(token::close_paren)) {
465            statement_ptr arg;
466            // unpacking: *expr
467            if (peek().t == token::multiplicative_binary_operator && peek().value == "*") {
468                size_t start_pos = current;
469                ++current; // consume *
470                arg = mk_stmt<spread_expression>(start_pos, parse_expression());
471            } else {
472                arg = parse_expression();
473                if (is(token::equals)) {
474                    // keyword argument
475                    // e.g., func(x = 5, y = a or b)
476                    size_t start_pos = current;
477                    ++current; // consume equals
478                    arg = mk_stmt<keyword_argument_expression>(start_pos, std::move(arg), parse_expression());
479                }
480            }
481            args.push_back(std::move(arg));
482            if (is(token::comma)) {
483                ++current; // consume comma
484            }
485        }
486        expect(token::close_paren, "Expected )");
487        return args;
488    }
489
490    statement_ptr parse_member_expression(statement_ptr object) {
491        size_t start_pos = current;
492        while (is(token::dot) || is(token::open_square_bracket)) {
493            auto op = tokens[current++];
494            bool computed = op.t == token::open_square_bracket;
495            statement_ptr prop;
496            if (computed) {
497                prop = parse_member_expression_arguments();
498                expect(token::close_square_bracket, "Expected ]");
499            } else {
500                prop = parse_primary_expression();
501            }
502            object = mk_stmt<member_expression>(start_pos, std::move(object), std::move(prop), computed);
503        }
504        return object;
505    }
506
507    statement_ptr parse_member_expression_arguments() {
508        // NOTE: This also handles slice expressions colon-separated arguments list
509        // e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3]
510        statements slices;
511        bool is_slice = false;
512        size_t start_pos = current;
513        while (!is(token::close_square_bracket)) {
514            if (is(token::colon)) {
515                // A case where a default is used
516                // e.g., [:2] will be parsed as [undefined, 2]
517                slices.push_back(nullptr);
518                ++current; // consume colon
519                is_slice = true;
520            } else {
521                slices.push_back(parse_expression());
522                if (is(token::colon)) {
523                    ++current; // consume colon after expression, if it exists
524                    is_slice = true;
525                }
526            }
527        }
528        if (is_slice) {
529            statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr;
530            statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr;
531            statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
532            return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
533        }
534        return std::move(slices[0]);
535    }
536
537    statement_ptr parse_primary_expression() {
538        size_t start_pos = current;
539        auto t = tokens[current++];
540        switch (t.t) {
541            case token::numeric_literal:
542                if (t.value.find('.') != std::string::npos) {
543                    return mk_stmt<float_literal>(start_pos, std::stod(t.value));
544                } else {
545                    return mk_stmt<integer_literal>(start_pos, std::stoll(t.value));
546                }
547            case token::string_literal: {
548                std::string val = t.value;
549                while (is(token::string_literal)) {
550                    val += tokens[current++].value;
551                }
552                return mk_stmt<string_literal>(start_pos, val);
553            }
554            case token::identifier:
555                return mk_stmt<identifier>(start_pos, t.value);
556            case token::open_paren: {
557                auto expr = parse_expression_sequence();
558                expect(token::close_paren, "Expected )");
559                return expr;
560            }
561            case token::open_square_bracket: {
562                statements vals;
563                while (!is(token::close_square_bracket)) {
564                    vals.push_back(parse_expression());
565                    if (is(token::comma)) current++;
566                }
567                current++;
568                return mk_stmt<array_literal>(start_pos, std::move(vals));
569            }
570            case token::open_curly_bracket: {
571                std::vector<std::pair<statement_ptr, statement_ptr>> pairs;
572                while (!is(token::close_curly_bracket)) {
573                    auto key = parse_expression();
574                    expect(token::colon, "Expected :");
575                    pairs.push_back({std::move(key), parse_expression()});
576                    if (is(token::comma)) current++;
577                }
578                current++;
579                return mk_stmt<object_literal>(start_pos, std::move(pairs));
580            }
581            default:
582                throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t));
583        }
584    }
585};
586
587program parse_from_tokens(const lexer_result & lexer_res) {
588    return parser(lexer_res.tokens, lexer_res.source).parse();
589}
590
591} // namespace jinja