1#pragma once
  2
  3#include "lexer.h"
  4#include "value.h"
  5
  6#include <cassert>
  7#include <ctime>
  8#include <memory>
  9#include <sstream>
 10#include <string>
 11#include <vector>
 12
 13#define JJ_DEBUG(msg, ...)  do { if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__); } while (0)
 14
 15extern bool g_jinja_debug;
 16
 17namespace jinja {
 18
 19struct statement;
 20using statement_ptr = std::unique_ptr<statement>;
 21using statements = std::vector<statement_ptr>;
 22
 23// Helpers for dynamic casting and type checking
 24template<typename T>
 25struct extract_pointee_unique {
 26    using type = T;
 27};
 28template<typename U>
 29struct extract_pointee_unique<std::unique_ptr<U>> {
 30    using type = U;
 31};
 32template<typename T>
 33bool is_stmt(const statement_ptr & ptr) {
 34    return dynamic_cast<const T*>(ptr.get()) != nullptr;
 35}
 36template<typename T>
 37T * cast_stmt(statement_ptr & ptr) {
 38    return dynamic_cast<T*>(ptr.get());
 39}
 40template<typename T>
 41const T * cast_stmt(const statement_ptr & ptr) {
 42    return dynamic_cast<const T*>(ptr.get());
 43}
 44// End Helpers
 45
 46
// Turns debug output on or off. Not thread-safe.
// NOTE(review): presumably this sets g_jinja_debug; the definition is not in this header - confirm.
void enable_debug(bool enable);
 49
 50struct context {
 51    std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation
 52    std::time_t current_time; // for functions that need current time
 53
 54    bool is_get_stats = false; // whether to collect stats
 55
 56    // src is optional, used for error reporting
 57    context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) {
 58        env = mk_val<value_object>();
 59        env->has_builtins = false; // context object has no builtins
 60        env->insert("true",  mk_val<value_bool>(true));
 61        env->insert("True",  mk_val<value_bool>(true));
 62        env->insert("false", mk_val<value_bool>(false));
 63        env->insert("False", mk_val<value_bool>(false));
 64        env->insert("none",  mk_val<value_none>());
 65        env->insert("None",  mk_val<value_none>());
 66        current_time = std::time(nullptr);
 67    }
 68    ~context() = default;
 69
 70    context(const context & parent) : context() {
 71        // inherit variables (for example, when entering a new scope)
 72        auto & pvar = parent.env->as_ordered_object();
 73        for (const auto & pair : pvar) {
 74            set_val(pair.first, pair.second);
 75        }
 76        current_time = parent.current_time;
 77        is_get_stats = parent.is_get_stats;
 78        src = parent.src;
 79    }
 80
 81    value get_val(const std::string & name) {
 82        value default_val = mk_val<value_undefined>(name);
 83        return env->at(name, default_val);
 84    }
 85
 86    void set_val(const std::string & name, const value & val) {
 87        env->insert(name, val);
 88    }
 89
 90    void set_val(const value & name, const value & val) {
 91        env->insert(name, val);
 92    }
 93
 94    void print_vars() const {
 95        printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str());
 96    }
 97
 98private:
 99    value_object env;
100};
101
102/**
103 * Base class for all nodes in the AST.
104 */
105struct statement {
106    size_t pos; // position in source, for debugging
107    virtual ~statement() = default;
108    virtual std::string type() const { return "Statement"; }
109    // execute_impl must be overridden by derived classes
110    virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); }
111    // execute is the public method to execute a statement with error handling
112    value execute(context &);
113};
114
115// Type Checking Utilities
116
117template<typename T>
118static void chk_type(const statement_ptr & ptr) {
119    if (!ptr) return; // Allow null for optional fields
120    assert(dynamic_cast<T *>(ptr.get()) != nullptr);
121}
122
123template<typename T, typename U>
124static void chk_type(const statement_ptr & ptr) {
125    if (!ptr) return;
126    assert(dynamic_cast<T *>(ptr.get()) != nullptr || dynamic_cast<U *>(ptr.get()) != nullptr);
127}
128
129// Base Types
130
131/**
132 * Expressions will result in a value at runtime (unlike statements).
133 */
134struct expression : public statement {
135    std::string type() const override { return "Expression"; }
136};
137
138// Statements
139
140struct program : public statement {
141    statements body;
142
143    program() = default;
144    explicit program(statements && body) : body(std::move(body)) {}
145    std::string type() const override { return "Program"; }
146    value execute_impl(context &) override {
147        throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead");
148    }
149};
150
151struct if_statement : public statement {
152    statement_ptr test;
153    statements body;
154    statements alternate;
155
156    if_statement(statement_ptr && test, statements && body, statements && alternate)
157        : test(std::move(test)), body(std::move(body)), alternate(std::move(alternate)) {
158        chk_type<expression>(this->test);
159    }
160
161    std::string type() const override { return "If"; }
162    value execute_impl(context & ctx) override;
163};
164
165struct identifier;
166struct tuple_literal;
167
168/**
169 * Loop over each item in a sequence
170 * https://jinja.palletsprojects.com/en/3.0.x/templates/#for
171 */
172struct for_statement : public statement {
173    statement_ptr loopvar; // Identifier | TupleLiteral
174    statement_ptr iterable;
175    statements body;
176    statements default_block; // if no iteration took place
177
178    for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block)
179        : loopvar(std::move(loopvar)), iterable(std::move(iterable)),
180          body(std::move(body)), default_block(std::move(default_block)) {
181        chk_type<identifier, tuple_literal>(this->loopvar);
182        chk_type<expression>(this->iterable);
183    }
184
185    std::string type() const override { return "For"; }
186    value execute_impl(context & ctx) override;
187};
188
189struct break_statement : public statement {
190    std::string type() const override { return "Break"; }
191
192    struct signal : public std::exception {
193        const char* what() const noexcept override {
194            return "Break statement executed";
195        }
196    };
197
198    value execute_impl(context &) override {
199        throw break_statement::signal();
200    }
201};
202
203struct continue_statement : public statement {
204    std::string type() const override { return "Continue"; }
205
206    struct signal : public std::exception {
207        const char* what() const noexcept override {
208            return "Continue statement executed";
209        }
210    };
211
212    value execute_impl(context &) override {
213        throw continue_statement::signal();
214    }
215};
216
217// do nothing
218struct noop_statement : public statement {
219    std::string type() const override { return "Noop"; }
220    value execute_impl(context &) override {
221        return mk_val<value_undefined>();
222    }
223};
224
225struct set_statement : public statement {
226    statement_ptr assignee;
227    statement_ptr val;
228    statements body;
229
230    set_statement(statement_ptr && assignee, statement_ptr && value, statements && body)
231        : assignee(std::move(assignee)), val(std::move(value)), body(std::move(body)) {
232        chk_type<expression>(this->assignee);
233        chk_type<expression>(this->val);
234    }
235
236    std::string type() const override { return "Set"; }
237    value execute_impl(context & ctx) override;
238};
239
240struct macro_statement : public statement {
241    statement_ptr name;
242    statements args;
243    statements body;
244
245    macro_statement(statement_ptr && name, statements && args, statements && body)
246        : name(std::move(name)), args(std::move(args)), body(std::move(body)) {
247        chk_type<identifier>(this->name);
248        for (const auto& arg : this->args) chk_type<expression>(arg);
249    }
250
251    std::string type() const override { return "Macro"; }
252    value execute_impl(context & ctx) override;
253};
254
255struct comment_statement : public statement {
256    std::string val;
257    explicit comment_statement(const std::string & v) : val(v) {}
258    std::string type() const override { return "Comment"; }
259    value execute_impl(context &) override {
260        return mk_val<value_undefined>();
261    }
262};
263
264// Expressions
265
266struct member_expression : public expression {
267    statement_ptr object;
268    statement_ptr property;
269    bool computed; // true if obj[expr] and false if obj.prop
270
271    member_expression(statement_ptr && object, statement_ptr && property, bool computed)
272        : object(std::move(object)), property(std::move(property)), computed(computed) {
273        chk_type<expression>(this->object);
274        chk_type<expression>(this->property);
275    }
276    std::string type() const override { return "MemberExpression"; }
277    value execute_impl(context & ctx) override;
278};
279
280struct call_expression : public expression {
281    statement_ptr callee;
282    statements args;
283
284    call_expression(statement_ptr && callee, statements && args)
285        : callee(std::move(callee)), args(std::move(args)) {
286        chk_type<expression>(this->callee);
287        for (const auto& arg : this->args) chk_type<expression>(arg);
288    }
289    std::string type() const override { return "CallExpression"; }
290    value execute_impl(context & ctx) override;
291};
292
293/**
294 * Represents a user-defined variable or symbol in the template.
295 */
296struct identifier : public expression {
297    std::string val;
298    explicit identifier(const std::string & val) : val(val) {}
299    std::string type() const override { return "Identifier"; }
300    value execute_impl(context & ctx) override;
301};
302
303// Literals
304
305struct integer_literal : public expression {
306    int64_t val;
307    explicit integer_literal(int64_t val) : val(val) {}
308    std::string type() const override { return "IntegerLiteral"; }
309    value execute_impl(context &) override {
310        return mk_val<value_int>(val);
311    }
312};
313
314struct float_literal : public expression {
315    double val;
316    explicit float_literal(double val) : val(val) {}
317    std::string type() const override { return "FloatLiteral"; }
318    value execute_impl(context &) override {
319        return mk_val<value_float>(val);
320    }
321};
322
323struct string_literal : public expression {
324    std::string val;
325    explicit string_literal(const std::string & val) : val(val) {}
326    std::string type() const override { return "StringLiteral"; }
327    value execute_impl(context &) override {
328        return mk_val<value_string>(val);
329    }
330};
331
332struct array_literal : public expression {
333    statements val;
334    explicit array_literal(statements && val) : val(std::move(val)) {
335        for (const auto& item : this->val) chk_type<expression>(item);
336    }
337    std::string type() const override { return "ArrayLiteral"; }
338    value execute_impl(context & ctx) override {
339        auto arr = mk_val<value_array>();
340        for (const auto & item_stmt : val) {
341            arr->push_back(item_stmt->execute(ctx));
342        }
343        return arr;
344    }
345};
346
347struct tuple_literal : public expression {
348    statements val;
349    explicit tuple_literal(statements && val) : val(std::move(val)) {
350        for (const auto& item : this->val) chk_type<expression>(item);
351    }
352    std::string type() const override { return "TupleLiteral"; }
353    value execute_impl(context & ctx) override {
354        auto arr = mk_val<value_array>();
355        for (const auto & item_stmt : val) {
356            arr->push_back(item_stmt->execute(ctx));
357        }
358        return mk_val<value_tuple>(std::move(arr->as_array()));
359    }
360};
361
362struct object_literal : public expression {
363    std::vector<std::pair<statement_ptr, statement_ptr>> val;
364    explicit object_literal(std::vector<std::pair<statement_ptr, statement_ptr>> && val)
365        : val(std::move(val)) {
366        for (const auto & pair : this->val) {
367            chk_type<expression>(pair.first);
368            chk_type<expression>(pair.second);
369        }
370    }
371    std::string type() const override { return "ObjectLiteral"; }
372    value execute_impl(context & ctx) override;
373};
374
375// Complex Expressions
376
377/**
378 * An operation with two sides, separated by an operator.
379 * Note: Either side can be a Complex Expression, with order
380 * of operations being determined by the operator.
381 */
382struct binary_expression : public expression {
383    token op;
384    statement_ptr left;
385    statement_ptr right;
386
387    binary_expression(token op, statement_ptr && left, statement_ptr && right)
388        : op(std::move(op)), left(std::move(left)), right(std::move(right)) {
389        chk_type<expression>(this->left);
390        chk_type<expression>(this->right);
391    }
392    std::string type() const override { return "BinaryExpression"; }
393    value execute_impl(context & ctx) override;
394};
395
396/**
397 * An operation with two sides, separated by the | operator.
398 * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202
399 */
400struct filter_expression : public expression {
401    // either an expression or a value is allowed
402    statement_ptr operand;
403    value_string val; // will be set by filter_statement
404
405    statement_ptr filter;
406
407    filter_expression(statement_ptr && operand, statement_ptr && filter)
408        : operand(std::move(operand)), filter(std::move(filter)) {
409        chk_type<expression>(this->operand);
410        chk_type<identifier, call_expression>(this->filter);
411    }
412
413    filter_expression(value_string && val, statement_ptr && filter)
414        : val(std::move(val)), filter(std::move(filter)) {
415        chk_type<identifier, call_expression>(this->filter);
416    }
417
418    std::string type() const override { return "FilterExpression"; }
419    value execute_impl(context & ctx) override;
420};
421
422struct filter_statement : public statement {
423    statement_ptr filter;
424    statements body;
425
426    filter_statement(statement_ptr && filter, statements && body)
427        : filter(std::move(filter)), body(std::move(body)) {
428        chk_type<identifier, call_expression>(this->filter);
429    }
430    std::string type() const override { return "FilterStatement"; }
431    value execute_impl(context & ctx) override;
432};
433
434/**
435 * An operation which filters a sequence of objects by applying a test to each object,
436 * and only selecting the objects with the test succeeding.
437 *
438 * It may also be used as a shortcut for a ternary operator.
439 */
440struct select_expression : public expression {
441    statement_ptr lhs;
442    statement_ptr test;
443
444    select_expression(statement_ptr && lhs, statement_ptr && test)
445        : lhs(std::move(lhs)), test(std::move(test)) {
446        chk_type<expression>(this->lhs);
447        chk_type<expression>(this->test);
448    }
449    std::string type() const override { return "SelectExpression"; }
450    value execute_impl(context & ctx) override {
451        auto predicate = test->execute_impl(ctx);
452        if (!predicate->as_bool()) {
453            return mk_val<value_undefined>();
454        }
455        return lhs->execute_impl(ctx);
456    }
457};
458
459/**
460 * An operation with two sides, separated by the "is" operator.
461 * NOTE: "value is something" translates to function call "test_is_something(value)"
462 */
463struct test_expression : public expression {
464    statement_ptr operand;
465    bool negate;
466    statement_ptr test;
467
468    test_expression(statement_ptr && operand, bool negate, statement_ptr && test)
469        : operand(std::move(operand)), negate(negate), test(std::move(test)) {
470        chk_type<expression>(this->operand);
471        chk_type<identifier, call_expression>(this->test);
472    }
473    std::string type() const override { return "TestExpression"; }
474    value execute_impl(context & ctx) override;
475};
476
477/**
478 * An operation with one side (operator on the left).
479 */
480struct unary_expression : public expression {
481    token op;
482    statement_ptr argument;
483
484    unary_expression(token op, statement_ptr && argument)
485        : op(std::move(op)), argument(std::move(argument)) {
486        chk_type<expression>(this->argument);
487    }
488    std::string type() const override { return "UnaryExpression"; }
489    value execute_impl(context & ctx) override;
490};
491
492struct slice_expression : public expression {
493    statement_ptr start_expr;
494    statement_ptr stop_expr;
495    statement_ptr step_expr;
496
497    slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr)
498        : start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) {
499        chk_type<expression>(this->start_expr);
500        chk_type<expression>(this->stop_expr);
501        chk_type<expression>(this->step_expr);
502    }
503    std::string type() const override { return "SliceExpression"; }
504    value execute_impl(context &) override {
505        throw std::runtime_error("must be handled by MemberExpression");
506    }
507};
508
509struct keyword_argument_expression : public expression {
510    statement_ptr key;
511    statement_ptr val;
512
513    keyword_argument_expression(statement_ptr && key, statement_ptr && val)
514        : key(std::move(key)), val(std::move(val)) {
515        chk_type<identifier>(this->key);
516        chk_type<expression>(this->val);
517    }
518    std::string type() const override { return "KeywordArgumentExpression"; }
519    value execute_impl(context & ctx) override;
520};
521
522struct spread_expression : public expression {
523    statement_ptr argument;
524    explicit spread_expression(statement_ptr && argument) : argument(std::move(argument)) {
525        chk_type<expression>(this->argument);
526    }
527    std::string type() const override { return "SpreadExpression"; }
528};
529
530struct call_statement : public statement {
531    statement_ptr call;
532    statements caller_args;
533    statements body;
534
535    call_statement(statement_ptr && call, statements && caller_args, statements && body)
536        : call(std::move(call)), caller_args(std::move(caller_args)), body(std::move(body)) {
537        chk_type<call_expression>(this->call);
538        for (const auto & arg : this->caller_args) chk_type<expression>(arg);
539    }
540    std::string type() const override { return "CallStatement"; }
541};
542
543struct ternary_expression : public expression {
544    statement_ptr condition;
545    statement_ptr true_expr;
546    statement_ptr false_expr;
547
548    ternary_expression(statement_ptr && condition, statement_ptr && true_expr, statement_ptr && false_expr)
549        : condition(std::move(condition)), true_expr(std::move(true_expr)), false_expr(std::move(false_expr)) {
550        chk_type<expression>(this->condition);
551        chk_type<expression>(this->true_expr);
552        chk_type<expression>(this->false_expr);
553    }
554    std::string type() const override { return "Ternary"; }
555    value execute_impl(context & ctx) override {
556        value cond_val = condition->execute(ctx);
557        if (cond_val->as_bool()) {
558            return true_expr->execute(ctx);
559        } else {
560            return false_expr->execute(ctx);
561        }
562    }
563};
564
// Exception carrying a message; what() returns that message.
// NOTE(review): the name suggests it backs a template-level raise; the throw site is not in this header.
struct raised_exception : public std::exception {
    std::string message;

    raised_exception(const std::string & msg)
        : message(msg) {
    }

    const char* what() const noexcept override {
        // data() is null-terminated and equivalent to c_str() since C++11
        return message.data();
    }
};
572
// Used to rethrow exceptions with modified messages
struct rethrown_exception : public std::exception {
    std::string message;

    rethrown_exception(const std::string & msg)
        : message(msg) {
    }

    const char* what() const noexcept override {
        // data() is null-terminated and equivalent to c_str() since C++11
        return message.data();
    }
};
581
582//////////////////////
583
584static void gather_string_parts_recursive(const value & val, value_string & parts) {
585    // TODO: probably allow print value_none as "None" string? currently this breaks some templates
586    if (is_val<value_string>(val)) {
587        const auto & str_val = cast_val<value_string>(val)->val_str;
588        parts->val_str.append(str_val);
589    } else if (is_val<value_int>(val) || is_val<value_float>(val) || is_val<value_bool>(val)) {
590        std::string str_val = val->as_string().str();
591        parts->val_str.append(str_val);
592    } else if (is_val<value_array>(val)) {
593        auto items = cast_val<value_array>(val)->as_array();
594        for (const auto & item : items) {
595            gather_string_parts_recursive(item, parts);
596        }
597    }
598}
599
600static std::string render_string_parts(const value_string & parts) {
601    std::ostringstream oss;
602    for (const auto & part : parts->val_str.parts) {
603        oss << part.val;
604    }
605    return oss.str();
606}
607
608struct runtime {
609    context & ctx;
610    explicit runtime(context & ctx) : ctx(ctx) {}
611
612    value_array execute(const program & prog) {
613        value_array results = mk_val<value_array>();
614        for (const auto & stmt : prog.body) {
615            value res = stmt->execute(ctx);
616            results->push_back(std::move(res));
617        }
618        return results;
619    }
620
621    static value_string gather_string_parts(const value & val) {
622        value_string parts = mk_val<value_string>();
623        gather_string_parts_recursive(val, parts);
624        // join consecutive parts with the same type
625        auto & p = parts->val_str.parts;
626        for (size_t i = 1; i < p.size(); ) {
627            if (p[i].is_input == p[i - 1].is_input) {
628                p[i - 1].val += p[i].val;
629                p.erase(p.begin() + i);
630            } else {
631                i++;
632            }
633        }
634        return parts;
635    }
636};
637
638} // namespace jinja