#include "lexer.h"
#include "runtime.h"
#include "parser.h"

#include <algorithm>
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
10
11#define FILENAME "jinja-parser"
12
13namespace jinja {
14
15// Helper to check type without asserting (useful for logic)
16template<typename T>
17static bool is_type(const statement_ptr & ptr) {
18 return dynamic_cast<const T*>(ptr.get()) != nullptr;
19}
20
21class parser {
22 const std::vector<token> & tokens;
23 size_t current = 0;
24
25 std::string source; // for error reporting
26
27public:
28 parser(const std::vector<token> & t, const std::string & src) : tokens(t), source(src) {}
29
30 program parse() {
31 statements body;
32 while (current < tokens.size()) {
33 body.push_back(parse_any());
34 }
35 return program(std::move(body));
36 }
37
38 // NOTE: start_pos is the token index, used for error reporting
39 template<typename T, typename... Args>
40 std::unique_ptr<T> mk_stmt(size_t start_pos, Args&&... args) {
41 auto ptr = std::make_unique<T>(std::forward<Args>(args)...);
42 assert(start_pos < tokens.size());
43 ptr->pos = tokens[start_pos].pos;
44 return ptr;
45 }
46
47private:
48 const token & peek(size_t offset = 0) const {
49 if (current + offset >= tokens.size()) {
50 static const token end_token{token::eof, "", 0};
51 return end_token;
52 }
53 return tokens[current + offset];
54 }
55
56 token expect(token::type type, const std::string& error) {
57 const auto & t = peek();
58 if (t.t != type) {
59 throw parser_exception("Parser Error: " + error + " (Got " + t.value + ")", source, t.pos);
60 }
61 current++;
62 return t;
63 }
64
65 void expect_identifier(const std::string & name) {
66 const auto & t = peek();
67 if (t.t != token::identifier || t.value != name) {
68 throw parser_exception("Expected identifier: " + name, source, t.pos);
69 }
70 current++;
71 }
72
73 bool is(token::type type) const {
74 return peek().t == type;
75 }
76
77 bool is_identifier(const std::string & name) const {
78 return peek().t == token::identifier && peek().value == name;
79 }
80
81 bool is_statement(const std::vector<std::string> & names) const {
82 if (peek(0).t != token::open_statement || peek(1).t != token::identifier) {
83 return false;
84 }
85 std::string val = peek(1).value;
86 return std::find(names.begin(), names.end(), val) != names.end();
87 }
88
89 statement_ptr parse_any() {
90 size_t start_pos = current;
91 switch (peek().t) {
92 case token::comment:
93 return mk_stmt<comment_statement>(start_pos, tokens[current++].value);
94 case token::text:
95 return mk_stmt<string_literal>(start_pos, tokens[current++].value);
96 case token::open_statement:
97 return parse_jinja_statement();
98 case token::open_expression:
99 return parse_jinja_expression();
100 default:
101 throw std::runtime_error("Unexpected token type");
102 }
103 }
104
105 statement_ptr parse_jinja_expression() {
106 // Consume {{ }} tokens
107 expect(token::open_expression, "Expected {{");
108 auto result = parse_expression();
109 expect(token::close_expression, "Expected }}");
110 return result;
111 }
112
113 statement_ptr parse_jinja_statement() {
114 // Consume {% token
115 expect(token::open_statement, "Expected {%");
116
117 if (peek().t != token::identifier) {
118 throw std::runtime_error("Unknown statement");
119 }
120
121 size_t start_pos = current;
122 std::string name = peek().value;
123 current++; // consume identifier
124
125 statement_ptr result;
126 if (name == "set") {
127 result = parse_set_statement(start_pos);
128
129 } else if (name == "if") {
130 result = parse_if_statement(start_pos);
131 // expect {% endif %}
132 expect(token::open_statement, "Expected {%");
133 expect_identifier("endif");
134 expect(token::close_statement, "Expected %}");
135
136 } else if (name == "macro") {
137 result = parse_macro_statement(start_pos);
138 // expect {% endmacro %}
139 expect(token::open_statement, "Expected {%");
140 expect_identifier("endmacro");
141 expect(token::close_statement, "Expected %}");
142
143 } else if (name == "for") {
144 result = parse_for_statement(start_pos);
145 // expect {% endfor %}
146 expect(token::open_statement, "Expected {%");
147 expect_identifier("endfor");
148 expect(token::close_statement, "Expected %}");
149
150 } else if (name == "break") {
151 expect(token::close_statement, "Expected %}");
152 result = mk_stmt<break_statement>(start_pos);
153
154 } else if (name == "continue") {
155 expect(token::close_statement, "Expected %}");
156 result = mk_stmt<continue_statement>(start_pos);
157
158 } else if (name == "call") {
159 statements caller_args;
160 // bool has_caller_args = false;
161 if (is(token::open_paren)) {
162 // Optional caller arguments, e.g. {% call(user) dump_users(...) %}
163 caller_args = parse_args();
164 // has_caller_args = true;
165 }
166 auto callee = parse_primary_expression();
167 if (!is_type<identifier>(callee)) throw std::runtime_error("Expected identifier");
168
169 auto call_args = parse_args();
170 expect(token::close_statement, "Expected %}");
171
172 statements body;
173 while (!is_statement({"endcall"})) {
174 body.push_back(parse_any());
175 }
176
177 expect(token::open_statement, "Expected {%");
178 expect_identifier("endcall");
179 expect(token::close_statement, "Expected %}");
180
181 auto call_expr = mk_stmt<call_expression>(start_pos, std::move(callee), std::move(call_args));
182 result = mk_stmt<call_statement>(start_pos, std::move(call_expr), std::move(caller_args), std::move(body));
183
184 } else if (name == "filter") {
185 auto filter_node = parse_primary_expression();
186 if (is_type<identifier>(filter_node) && is(token::open_paren)) {
187 filter_node = parse_call_expression(std::move(filter_node));
188 }
189 expect(token::close_statement, "Expected %}");
190
191 statements body;
192 while (!is_statement({"endfilter"})) {
193 body.push_back(parse_any());
194 }
195
196 expect(token::open_statement, "Expected {%");
197 expect_identifier("endfilter");
198 expect(token::close_statement, "Expected %}");
199 result = mk_stmt<filter_statement>(start_pos, std::move(filter_node), std::move(body));
200
201 } else if (name == "generation" || name == "endgeneration") {
202 // Ignore generation blocks (transformers-specific)
203 // See https://github.com/huggingface/transformers/pull/30650 for more information.
204 result = mk_stmt<noop_statement>(start_pos);
205 current++;
206
207 } else {
208 throw std::runtime_error("Unknown statement: " + name);
209 }
210 return result;
211 }
212
213 statement_ptr parse_set_statement(size_t start_pos) {
214 // NOTE: `set` acts as both declaration statement and assignment expression
215 auto left = parse_expression_sequence();
216 statement_ptr value = nullptr;
217 statements body;
218
219 if (is(token::equals)) {
220 current++;
221 value = parse_expression_sequence();
222 } else {
223 // parsing multiline set here
224 expect(token::close_statement, "Expected %}");
225 while (!is_statement({"endset"})) {
226 body.push_back(parse_any());
227 }
228 expect(token::open_statement, "Expected {%");
229 expect_identifier("endset");
230 }
231 expect(token::close_statement, "Expected %}");
232 return mk_stmt<set_statement>(start_pos, std::move(left), std::move(value), std::move(body));
233 }
234
235 statement_ptr parse_if_statement(size_t start_pos) {
236 auto test = parse_expression();
237 expect(token::close_statement, "Expected %}");
238
239 statements body;
240 statements alternate;
241
242 // Keep parsing 'if' body until we reach the first {% elif %} or {% else %} or {% endif %}
243 while (!is_statement({"elif", "else", "endif"})) {
244 body.push_back(parse_any());
245 }
246
247 if (is_statement({"elif"})) {
248 size_t pos0 = current;
249 ++current; // consume {%
250 ++current; // consume 'elif'
251 alternate.push_back(parse_if_statement(pos0)); // nested If
252 } else if (is_statement({"else"})) {
253 ++current; // consume {%
254 ++current; // consume 'else'
255 expect(token::close_statement, "Expected %}");
256
257 // keep going until we hit {% endif %}
258 while (!is_statement({"endif"})) {
259 alternate.push_back(parse_any());
260 }
261 }
262 return mk_stmt<if_statement>(start_pos, std::move(test), std::move(body), std::move(alternate));
263 }
264
265 statement_ptr parse_macro_statement(size_t start_pos) {
266 auto name = parse_primary_expression();
267 auto args = parse_args();
268 expect(token::close_statement, "Expected %}");
269 statements body;
270 // Keep going until we hit {% endmacro
271 while (!is_statement({"endmacro"})) {
272 body.push_back(parse_any());
273 }
274 return mk_stmt<macro_statement>(start_pos, std::move(name), std::move(args), std::move(body));
275 }
276
277 statement_ptr parse_expression_sequence(bool primary = false) {
278 size_t start_pos = current;
279 statements exprs;
280 exprs.push_back(primary ? parse_primary_expression() : parse_expression());
281 bool is_tuple = is(token::comma);
282 while (is(token::comma)) {
283 current++; // consume comma
284 exprs.push_back(primary ? parse_primary_expression() : parse_expression());
285 }
286 return is_tuple ? mk_stmt<tuple_literal>(start_pos, std::move(exprs)) : std::move(exprs[0]);
287 }
288
289 statement_ptr parse_for_statement(size_t start_pos) {
290 // e.g., `message` in `for message in messages`
291 auto loop_var = parse_expression_sequence(true); // should be an identifier/tuple
292 if (!is_identifier("in")) throw std::runtime_error("Expected 'in'");
293 current++;
294
295 // `messages` in `for message in messages`
296 auto iterable = parse_expression();
297 expect(token::close_statement, "Expected %}");
298
299 statements body;
300 statements alternate;
301
302 // Keep going until we hit {% endfor or {% else
303 while (!is_statement({"endfor", "else"})) {
304 body.push_back(parse_any());
305 }
306
307 if (is_statement({"else"})) {
308 current += 2;
309 expect(token::close_statement, "Expected %}");
310 while (!is_statement({"endfor"})) {
311 alternate.push_back(parse_any());
312 }
313 }
314 return mk_stmt<for_statement>(
315 start_pos,
316 std::move(loop_var), std::move(iterable),
317 std::move(body), std::move(alternate));
318 }
319
320 statement_ptr parse_expression() {
321 // Choose parse function with lowest precedence
322 return parse_if_expression();
323 }
324
325 statement_ptr parse_if_expression() {
326 auto a = parse_logical_or_expression();
327 if (is_identifier("if")) {
328 // Ternary expression
329 size_t start_pos = current;
330 ++current; // consume 'if'
331 auto test = parse_logical_or_expression();
332 if (is_identifier("else")) {
333 // Ternary expression with else
334 size_t pos0 = current;
335 ++current; // consume 'else'
336 auto false_expr = parse_if_expression(); // recurse to support chained ternaries
337 return mk_stmt<ternary_expression>(pos0, std::move(test), std::move(a), std::move(false_expr));
338 } else {
339 // Select expression on iterable
340 return mk_stmt<select_expression>(start_pos, std::move(a), std::move(test));
341 }
342 }
343 return a;
344 }
345
346 statement_ptr parse_logical_or_expression() {
347 auto left = parse_logical_and_expression();
348 while (is_identifier("or")) {
349 size_t start_pos = current;
350 token op = tokens[current++];
351 left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_and_expression());
352 }
353 return left;
354 }
355
356 statement_ptr parse_logical_and_expression() {
357 auto left = parse_logical_negation_expression();
358 while (is_identifier("and")) {
359 size_t start_pos = current;
360 auto op = tokens[current++];
361 left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_logical_negation_expression());
362 }
363 return left;
364 }
365
366 statement_ptr parse_logical_negation_expression() {
367 // Try parse unary operators
368 if (is_identifier("not")) {
369 size_t start_pos = current;
370 auto op = tokens[current++];
371 return mk_stmt<unary_expression>(start_pos, op, parse_logical_negation_expression());
372 }
373 return parse_comparison_expression();
374 }
375
376 statement_ptr parse_comparison_expression() {
377 // NOTE: membership has same precedence as comparison
378 // e.g., ('a' in 'apple' == 'b' in 'banana') evaluates as ('a' in ('apple' == ('b' in 'banana')))
379 auto left = parse_additive_expression();
380 while (true) {
381 token op;
382 size_t start_pos = current;
383 if (is_identifier("not") && peek(1).t == token::identifier && peek(1).value == "in") {
384 op = {token::identifier, "not in", tokens[current].pos};
385 current += 2;
386 } else if (is_identifier("in")) {
387 op = tokens[current++];
388 } else if (is(token::comparison_binary_operator)) {
389 op = tokens[current++];
390 } else break;
391 left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_additive_expression());
392 }
393 return left;
394 }
395
396 statement_ptr parse_additive_expression() {
397 auto left = parse_multiplicative_expression();
398 while (is(token::additive_binary_operator)) {
399 size_t start_pos = current;
400 auto op = tokens[current++];
401 left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_multiplicative_expression());
402 }
403 return left;
404 }
405
406 statement_ptr parse_multiplicative_expression() {
407 auto left = parse_test_expression();
408 while (is(token::multiplicative_binary_operator)) {
409 size_t start_pos = current;
410 auto op = tokens[current++];
411 left = mk_stmt<binary_expression>(start_pos, op, std::move(left), parse_test_expression());
412 }
413 return left;
414 }
415
416 statement_ptr parse_test_expression() {
417 auto operand = parse_filter_expression();
418 while (is_identifier("is")) {
419 size_t start_pos = current;
420 current++;
421 bool negate = false;
422 if (is_identifier("not")) { current++; negate = true; }
423 auto test_id = parse_primary_expression();
424 // FIXME: tests can also be expressed like this: if x is eq 3
425 if (is(token::open_paren)) test_id = parse_call_expression(std::move(test_id));
426 operand = mk_stmt<test_expression>(start_pos, std::move(operand), negate, std::move(test_id));
427 }
428 return operand;
429 }
430
431 statement_ptr parse_filter_expression() {
432 auto operand = parse_call_member_expression();
433 while (is(token::pipe)) {
434 size_t start_pos = current;
435 current++;
436 auto filter = parse_primary_expression();
437 if (is(token::open_paren)) filter = parse_call_expression(std::move(filter));
438 operand = mk_stmt<filter_expression>(start_pos, std::move(operand), std::move(filter));
439 }
440 return operand;
441 }
442
443 statement_ptr parse_call_member_expression() {
444 // Handle member expressions recursively
445 auto member = parse_member_expression(parse_primary_expression());
446 return is(token::open_paren)
447 ? parse_call_expression(std::move(member)) // foo.x()
448 : std::move(member);
449 }
450
451 statement_ptr parse_call_expression(statement_ptr callee) {
452 size_t start_pos = current;
453 auto expr = mk_stmt<call_expression>(start_pos, std::move(callee), parse_args());
454 auto member = parse_member_expression(std::move(expr)); // foo.x().y
455 return is(token::open_paren)
456 ? parse_call_expression(std::move(member)) // foo.x()()
457 : std::move(member);
458 }
459
460 statements parse_args() {
461 // comma-separated arguments list
462 expect(token::open_paren, "Expected (");
463 statements args;
464 while (!is(token::close_paren)) {
465 statement_ptr arg;
466 // unpacking: *expr
467 if (peek().t == token::multiplicative_binary_operator && peek().value == "*") {
468 size_t start_pos = current;
469 ++current; // consume *
470 arg = mk_stmt<spread_expression>(start_pos, parse_expression());
471 } else {
472 arg = parse_expression();
473 if (is(token::equals)) {
474 // keyword argument
475 // e.g., func(x = 5, y = a or b)
476 size_t start_pos = current;
477 ++current; // consume equals
478 arg = mk_stmt<keyword_argument_expression>(start_pos, std::move(arg), parse_expression());
479 }
480 }
481 args.push_back(std::move(arg));
482 if (is(token::comma)) {
483 ++current; // consume comma
484 }
485 }
486 expect(token::close_paren, "Expected )");
487 return args;
488 }
489
490 statement_ptr parse_member_expression(statement_ptr object) {
491 size_t start_pos = current;
492 while (is(token::dot) || is(token::open_square_bracket)) {
493 auto op = tokens[current++];
494 bool computed = op.t == token::open_square_bracket;
495 statement_ptr prop;
496 if (computed) {
497 prop = parse_member_expression_arguments();
498 expect(token::close_square_bracket, "Expected ]");
499 } else {
500 prop = parse_primary_expression();
501 }
502 object = mk_stmt<member_expression>(start_pos, std::move(object), std::move(prop), computed);
503 }
504 return object;
505 }
506
507 statement_ptr parse_member_expression_arguments() {
508 // NOTE: This also handles slice expressions colon-separated arguments list
509 // e.g., ['test'], [0], [:2], [1:], [1:2], [1:2:3]
510 statements slices;
511 bool is_slice = false;
512 size_t start_pos = current;
513 while (!is(token::close_square_bracket)) {
514 if (is(token::colon)) {
515 // A case where a default is used
516 // e.g., [:2] will be parsed as [undefined, 2]
517 slices.push_back(nullptr);
518 ++current; // consume colon
519 is_slice = true;
520 } else {
521 slices.push_back(parse_expression());
522 if (is(token::colon)) {
523 ++current; // consume colon after expression, if it exists
524 is_slice = true;
525 }
526 }
527 }
528 if (is_slice) {
529 statement_ptr start = slices.size() > 0 ? std::move(slices[0]) : nullptr;
530 statement_ptr stop = slices.size() > 1 ? std::move(slices[1]) : nullptr;
531 statement_ptr step = slices.size() > 2 ? std::move(slices[2]) : nullptr;
532 return mk_stmt<slice_expression>(start_pos, std::move(start), std::move(stop), std::move(step));
533 }
534 return std::move(slices[0]);
535 }
536
537 statement_ptr parse_primary_expression() {
538 size_t start_pos = current;
539 auto t = tokens[current++];
540 switch (t.t) {
541 case token::numeric_literal:
542 if (t.value.find('.') != std::string::npos) {
543 return mk_stmt<float_literal>(start_pos, std::stod(t.value));
544 } else {
545 return mk_stmt<integer_literal>(start_pos, std::stoll(t.value));
546 }
547 case token::string_literal: {
548 std::string val = t.value;
549 while (is(token::string_literal)) {
550 val += tokens[current++].value;
551 }
552 return mk_stmt<string_literal>(start_pos, val);
553 }
554 case token::identifier:
555 return mk_stmt<identifier>(start_pos, t.value);
556 case token::open_paren: {
557 auto expr = parse_expression_sequence();
558 expect(token::close_paren, "Expected )");
559 return expr;
560 }
561 case token::open_square_bracket: {
562 statements vals;
563 while (!is(token::close_square_bracket)) {
564 vals.push_back(parse_expression());
565 if (is(token::comma)) current++;
566 }
567 current++;
568 return mk_stmt<array_literal>(start_pos, std::move(vals));
569 }
570 case token::open_curly_bracket: {
571 std::vector<std::pair<statement_ptr, statement_ptr>> pairs;
572 while (!is(token::close_curly_bracket)) {
573 auto key = parse_expression();
574 expect(token::colon, "Expected :");
575 pairs.push_back({std::move(key), parse_expression()});
576 if (is(token::comma)) current++;
577 }
578 current++;
579 return mk_stmt<object_literal>(start_pos, std::move(pairs));
580 }
581 default:
582 throw std::runtime_error("Unexpected token: " + t.value + " of type " + std::to_string(t.t));
583 }
584 }
585};
586
587program parse_from_tokens(const lexer_result & lexer_res) {
588 return parser(lexer_res.tokens, lexer_res.source).parse();
589}
590
591} // namespace jinja