diff options
Diffstat (limited to 'llama.cpp/common/jinja/runtime.h')
| -rw-r--r-- | llama.cpp/common/jinja/runtime.h | 638 |
1 files changed, 638 insertions, 0 deletions
diff --git a/llama.cpp/common/jinja/runtime.h b/llama.cpp/common/jinja/runtime.h new file mode 100644 index 0000000..17a6dff --- /dev/null +++ b/llama.cpp/common/jinja/runtime.h @@ -0,0 +1,638 @@ +#pragma once + +#include "lexer.h" +#include "value.h" + +#include <cassert> +#include <ctime> +#include <memory> +#include <sstream> +#include <string> +#include <vector> + +#define JJ_DEBUG(msg, ...) do { if (g_jinja_debug) printf("%s:%-3d : " msg "\n", FILENAME, __LINE__, __VA_ARGS__); } while (0) + +extern bool g_jinja_debug; + +namespace jinja { + +struct statement; +using statement_ptr = std::unique_ptr<statement>; +using statements = std::vector<statement_ptr>; + +// Helpers for dynamic casting and type checking +template<typename T> +struct extract_pointee_unique { + using type = T; +}; +template<typename U> +struct extract_pointee_unique<std::unique_ptr<U>> { + using type = U; +}; +template<typename T> +bool is_stmt(const statement_ptr & ptr) { + return dynamic_cast<const T*>(ptr.get()) != nullptr; +} +template<typename T> +T * cast_stmt(statement_ptr & ptr) { + return dynamic_cast<T*>(ptr.get()); +} +template<typename T> +const T * cast_stmt(const statement_ptr & ptr) { + return dynamic_cast<const T*>(ptr.get()); +} +// End Helpers + + +// not thread-safe +void enable_debug(bool enable); + +struct context { + std::shared_ptr<std::string> src; // for debugging; use shared_ptr to avoid copying on scope creation + std::time_t current_time; // for functions that need current time + + bool is_get_stats = false; // whether to collect stats + + // src is optional, used for error reporting + context(std::string src = "") : src(std::make_shared<std::string>(std::move(src))) { + env = mk_val<value_object>(); + env->has_builtins = false; // context object has no builtins + env->insert("true", mk_val<value_bool>(true)); + env->insert("True", mk_val<value_bool>(true)); + env->insert("false", mk_val<value_bool>(false)); + env->insert("False", mk_val<value_bool>(false)); + env->insert("none", mk_val<value_none>()); + env->insert("None", mk_val<value_none>()); + current_time = std::time(nullptr); + } + ~context() = default; + + context(const context & parent) : context() { + // inherit variables (for example, when entering a new scope) + auto & pvar = parent.env->as_ordered_object(); + for (const auto & pair : pvar) { + set_val(pair.first, pair.second); + } + current_time = parent.current_time; + is_get_stats = parent.is_get_stats; + src = parent.src; + } + + value get_val(const std::string & name) { + value default_val = mk_val<value_undefined>(name); + return env->at(name, default_val); + } + + void set_val(const std::string & name, const value & val) { + env->insert(name, val); + } + + void set_val(const value & name, const value & val) { + env->insert(name, val); + } + + void print_vars() const { + printf("Context Variables:\n%s\n", value_to_json(env, 2).c_str()); + } + +private: + value_object env; +}; + +/** + * Base class for all nodes in the AST. + */ +struct statement { + size_t pos; // position in source, for debugging + virtual ~statement() = default; + virtual std::string type() const { return "Statement"; } + // execute_impl must be overridden by derived classes + virtual value execute_impl(context &) { throw std::runtime_error("cannot exec " + type()); } + // execute is the public method to execute a statement with error handling + value execute(context &); +}; + +// Type Checking Utilities + +template<typename T> +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; // Allow null for optional fields + assert(dynamic_cast<T *>(ptr.get()) != nullptr); +} + +template<typename T, typename U> +static void chk_type(const statement_ptr & ptr) { + if (!ptr) return; + assert(dynamic_cast<T *>(ptr.get()) != nullptr || dynamic_cast<U *>(ptr.get()) != nullptr); +} + +// Base Types + +/** + * Expressions will result in a value at runtime (unlike statements). + */ +struct expression : public statement { + std::string type() const override { return "Expression"; } +}; + +// Statements + +struct program : public statement { + statements body; + + program() = default; + explicit program(statements && body) : body(std::move(body)) {} + std::string type() const override { return "Program"; } + value execute_impl(context &) override { + throw std::runtime_error("Cannot execute program directly, use jinja::runtime instead"); + } +}; + +struct if_statement : public statement { + statement_ptr test; + statements body; + statements alternate; + + if_statement(statement_ptr && test, statements && body, statements && alternate) + : test(std::move(test)), body(std::move(body)), alternate(std::move(alternate)) { + chk_type<expression>(this->test); + } + + std::string type() const override { return "If"; } + value execute_impl(context & ctx) override; +}; + +struct identifier; +struct tuple_literal; + +/** + * Loop over each item in a sequence + * https://jinja.palletsprojects.com/en/3.0.x/templates/#for + */ +struct for_statement : public statement { + statement_ptr loopvar; // Identifier | TupleLiteral + statement_ptr iterable; + statements body; + statements default_block; // if no iteration took place + + for_statement(statement_ptr && loopvar, statement_ptr && iterable, statements && body, statements && default_block) + : loopvar(std::move(loopvar)), iterable(std::move(iterable)), + body(std::move(body)), default_block(std::move(default_block)) { + chk_type<identifier, tuple_literal>(this->loopvar); + chk_type<expression>(this->iterable); + } + + std::string type() const override { return "For"; } + value execute_impl(context & ctx) override; +}; + +struct break_statement : public statement { + std::string type() const override { return "Break"; } + + struct signal : public std::exception { + const char* what() const noexcept override { + return "Break statement executed"; + } + }; + + value execute_impl(context &) override { + throw break_statement::signal(); + } +}; + +struct continue_statement : public statement { + std::string type() const override { return "Continue"; } + + struct signal : public std::exception { + const char* what() const noexcept override { + return "Continue statement executed"; + } + }; + + value execute_impl(context &) override { + throw continue_statement::signal(); + } +}; + +// do nothing +struct noop_statement : public statement { + std::string type() const override { return "Noop"; } + value execute_impl(context &) override { + return mk_val<value_undefined>(); + } +}; + +struct set_statement : public statement { + statement_ptr assignee; + statement_ptr val; + statements body; + + set_statement(statement_ptr && assignee, statement_ptr && value, statements && body) + : assignee(std::move(assignee)), val(std::move(value)), body(std::move(body)) { + chk_type<expression>(this->assignee); + chk_type<expression>(this->val); + } + + std::string type() const override { return "Set"; } + value execute_impl(context & ctx) override; +}; + +struct macro_statement : public statement { + statement_ptr name; + statements args; + statements body; + + macro_statement(statement_ptr && name, statements && args, statements && body) + : name(std::move(name)), args(std::move(args)), body(std::move(body)) { + chk_type<identifier>(this->name); + for (const auto& arg : this->args) chk_type<expression>(arg); + } + + std::string type() const override { return "Macro"; } + value execute_impl(context & ctx) override; +}; + +struct comment_statement : public statement { + std::string val; + explicit comment_statement(const std::string & v) : val(v) {} + std::string type() const override { return "Comment"; } + value execute_impl(context &) override { + return mk_val<value_undefined>(); + } +}; + +// Expressions + +struct member_expression : public expression { + statement_ptr object; + statement_ptr property; + bool computed; // true if obj[expr] and false if obj.prop + + member_expression(statement_ptr && object, statement_ptr && property, bool computed) + : object(std::move(object)), property(std::move(property)), computed(computed) { + chk_type<expression>(this->object); + chk_type<expression>(this->property); + } + std::string type() const override { return "MemberExpression"; } + value execute_impl(context & ctx) override; +}; + +struct call_expression : public expression { + statement_ptr callee; + statements args; + + call_expression(statement_ptr && callee, statements && args) + : callee(std::move(callee)), args(std::move(args)) { + chk_type<expression>(this->callee); + for (const auto& arg : this->args) chk_type<expression>(arg); + } + std::string type() const override { return "CallExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * Represents a user-defined variable or symbol in the template. + */ +struct identifier : public expression { + std::string val; + explicit identifier(const std::string & val) : val(val) {} + std::string type() const override { return "Identifier"; } + value execute_impl(context & ctx) override; +}; + +// Literals + +struct integer_literal : public expression { + int64_t val; + explicit integer_literal(int64_t val) : val(val) {} + std::string type() const override { return "IntegerLiteral"; } + value execute_impl(context &) override { + return mk_val<value_int>(val); + } +}; + +struct float_literal : public expression { + double val; + explicit float_literal(double val) : val(val) {} + std::string type() const override { return "FloatLiteral"; } + value execute_impl(context &) override { + return mk_val<value_float>(val); + } +}; + +struct string_literal : public expression { + std::string val; + explicit string_literal(const std::string & val) : val(val) {} + std::string type() const override { return "StringLiteral"; } + value execute_impl(context &) override { + return mk_val<value_string>(val); + } +}; + +struct array_literal : public expression { + statements val; + explicit array_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type<expression>(item); + } + std::string type() const override { return "ArrayLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val<value_array>(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return arr; + } +}; + +struct tuple_literal : public expression { + statements val; + explicit tuple_literal(statements && val) : val(std::move(val)) { + for (const auto& item : this->val) chk_type<expression>(item); + } + std::string type() const override { return "TupleLiteral"; } + value execute_impl(context & ctx) override { + auto arr = mk_val<value_array>(); + for (const auto & item_stmt : val) { + arr->push_back(item_stmt->execute(ctx)); + } + return mk_val<value_tuple>(std::move(arr->as_array())); + } +}; + +struct object_literal : public expression { + std::vector<std::pair<statement_ptr, statement_ptr>> val; + explicit object_literal(std::vector<std::pair<statement_ptr, statement_ptr>> && val) + : val(std::move(val)) { + for (const auto & pair : this->val) { + chk_type<expression>(pair.first); + chk_type<expression>(pair.second); + } + } + std::string type() const override { return "ObjectLiteral"; } + value execute_impl(context & ctx) override; +}; + +// Complex Expressions + +/** + * An operation with two sides, separated by an operator. + * Note: Either side can be a Complex Expression, with order + * of operations being determined by the operator. + */ +struct binary_expression : public expression { + token op; + statement_ptr left; + statement_ptr right; + + binary_expression(token op, statement_ptr && left, statement_ptr && right) + : op(std::move(op)), left(std::move(left)), right(std::move(right)) { + chk_type<expression>(this->left); + chk_type<expression>(this->right); + } + std::string type() const override { return "BinaryExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation with two sides, separated by the | operator. + * Operator precedence: https://github.com/pallets/jinja/issues/379#issuecomment-168076202 + */ +struct filter_expression : public expression { + // either an expression or a value is allowed + statement_ptr operand; + value_string val; // will be set by filter_statement + + statement_ptr filter; + + filter_expression(statement_ptr && operand, statement_ptr && filter) + : operand(std::move(operand)), filter(std::move(filter)) { + chk_type<expression>(this->operand); + chk_type<identifier, call_expression>(this->filter); + } + + filter_expression(value_string && val, statement_ptr && filter) + : val(std::move(val)), filter(std::move(filter)) { + chk_type<identifier, call_expression>(this->filter); + } + + std::string type() const override { return "FilterExpression"; } + value execute_impl(context & ctx) override; +}; + +struct filter_statement : public statement { + statement_ptr filter; + statements body; + + filter_statement(statement_ptr && filter, statements && body) + : filter(std::move(filter)), body(std::move(body)) { + chk_type<identifier, call_expression>(this->filter); + } + std::string type() const override { return "FilterStatement"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation which filters a sequence of objects by applying a test to each object, + * and only selecting the objects with the test succeeding. + * + * It may also be used as a shortcut for a ternary operator. + */ +struct select_expression : public expression { + statement_ptr lhs; + statement_ptr test; + + select_expression(statement_ptr && lhs, statement_ptr && test) + : lhs(std::move(lhs)), test(std::move(test)) { + chk_type<expression>(this->lhs); + chk_type<expression>(this->test); + } + std::string type() const override { return "SelectExpression"; } + value execute_impl(context & ctx) override { + auto predicate = test->execute_impl(ctx); + if (!predicate->as_bool()) { + return mk_val<value_undefined>(); + } + return lhs->execute_impl(ctx); + } +}; + +/** + * An operation with two sides, separated by the "is" operator. + * NOTE: "value is something" translates to function call "test_is_something(value)" + */ +struct test_expression : public expression { + statement_ptr operand; + bool negate; + statement_ptr test; + + test_expression(statement_ptr && operand, bool negate, statement_ptr && test) + : operand(std::move(operand)), negate(negate), test(std::move(test)) { + chk_type<expression>(this->operand); + chk_type<identifier, call_expression>(this->test); + } + std::string type() const override { return "TestExpression"; } + value execute_impl(context & ctx) override; +}; + +/** + * An operation with one side (operator on the left). + */ +struct unary_expression : public expression { + token op; + statement_ptr argument; + + unary_expression(token op, statement_ptr && argument) + : op(std::move(op)), argument(std::move(argument)) { + chk_type<expression>(this->argument); + } + std::string type() const override { return "UnaryExpression"; } + value execute_impl(context & ctx) override; +}; + +struct slice_expression : public expression { + statement_ptr start_expr; + statement_ptr stop_expr; + statement_ptr step_expr; + + slice_expression(statement_ptr && start_expr, statement_ptr && stop_expr, statement_ptr && step_expr) + : start_expr(std::move(start_expr)), stop_expr(std::move(stop_expr)), step_expr(std::move(step_expr)) { + chk_type<expression>(this->start_expr); + chk_type<expression>(this->stop_expr); + chk_type<expression>(this->step_expr); + } + std::string type() const override { return "SliceExpression"; } + value execute_impl(context &) override { + throw std::runtime_error("must be handled by MemberExpression"); + } +}; + +struct keyword_argument_expression : public expression { + statement_ptr key; + statement_ptr val; + + keyword_argument_expression(statement_ptr && key, statement_ptr && val) + : key(std::move(key)), val(std::move(val)) { + chk_type<identifier>(this->key); + chk_type<expression>(this->val); + } + std::string type() const override { return "KeywordArgumentExpression"; } + value execute_impl(context & ctx) override; +}; + +struct spread_expression : public expression { + statement_ptr argument; + explicit spread_expression(statement_ptr && argument) : argument(std::move(argument)) { + chk_type<expression>(this->argument); + } + std::string type() const override { return "SpreadExpression"; } +}; + +struct call_statement : public statement { + statement_ptr call; + statements caller_args; + statements body; + + call_statement(statement_ptr && call, statements && caller_args, statements && body) + : call(std::move(call)), caller_args(std::move(caller_args)), body(std::move(body)) { + chk_type<call_expression>(this->call); + for (const auto & arg : this->caller_args) chk_type<expression>(arg); + } + std::string type() const override { return "CallStatement"; } +}; + +struct ternary_expression : public expression { + statement_ptr condition; + statement_ptr true_expr; + statement_ptr false_expr; + + ternary_expression(statement_ptr && condition, statement_ptr && true_expr, statement_ptr && false_expr) + : condition(std::move(condition)), true_expr(std::move(true_expr)), false_expr(std::move(false_expr)) { + chk_type<expression>(this->condition); + chk_type<expression>(this->true_expr); + chk_type<expression>(this->false_expr); + } + std::string type() const override { return "Ternary"; } + value execute_impl(context & ctx) override { + value cond_val = condition->execute(ctx); + if (cond_val->as_bool()) { + return true_expr->execute(ctx); + } else { + return false_expr->execute(ctx); + } + } +}; + +struct raised_exception : public std::exception { + std::string message; + raised_exception(const std::string & msg) : message(msg) {} + const char* what() const noexcept override { + return message.c_str(); + } +}; + +// Used to rethrow exceptions with modified messages +struct rethrown_exception : public std::exception { + std::string message; + rethrown_exception(const std::string & msg) : message(msg) {} + const char* what() const noexcept override { + return message.c_str(); + } +}; + +////////////////////// + +static void gather_string_parts_recursive(const value & val, value_string & parts) { + // TODO: probably allow print value_none as "None" string? currently this breaks some templates + if (is_val<value_string>(val)) { + const auto & str_val = cast_val<value_string>(val)->val_str; + parts->val_str.append(str_val); + } else if (is_val<value_int>(val) || is_val<value_float>(val) || is_val<value_bool>(val)) { + std::string str_val = val->as_string().str(); + parts->val_str.append(str_val); + } else if (is_val<value_array>(val)) { + auto items = cast_val<value_array>(val)->as_array(); + for (const auto & item : items) { + gather_string_parts_recursive(item, parts); + } + } +} + +static std::string render_string_parts(const value_string & parts) { + std::ostringstream oss; + for (const auto & part : parts->val_str.parts) { + oss << part.val; + } + return oss.str(); +} + +struct runtime { + context & ctx; + explicit runtime(context & ctx) : ctx(ctx) {} + + value_array execute(const program & prog) { + value_array results = mk_val<value_array>(); + for (const auto & stmt : prog.body) { + value res = stmt->execute(ctx); + results->push_back(std::move(res)); + } + return results; + } + + static value_string gather_string_parts(const value & val) { + value_string parts = mk_val<value_string>(); + gather_string_parts_recursive(val, parts); + // join consecutive parts with the same type + auto & p = parts->val_str.parts; + for (size_t i = 1; i < p.size(); ) { + if (p[i].is_input == p[i - 1].is_input) { + p[i - 1].val += p[i].val; + p.erase(p.begin() + i); + } else { + i++; + } + } + return parts; + } +}; + +} // namespace jinja |
