summaryrefslogtreecommitdiff
path: root/llama.cpp/common/jinja/lexer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/common/jinja/lexer.cpp')
-rw-r--r--llama.cpp/common/jinja/lexer.cpp341
1 files changed, 341 insertions, 0 deletions
diff --git a/llama.cpp/common/jinja/lexer.cpp b/llama.cpp/common/jinja/lexer.cpp
new file mode 100644
index 0000000..598982c
--- /dev/null
+++ b/llama.cpp/common/jinja/lexer.cpp
@@ -0,0 +1,341 @@
+#include "lexer.h"
+#include "runtime.h"
+
+#include <cctype>
+#include <functional>
+#include <map>
+#include <string>
+#include <vector>
+
+#define FILENAME "jinja-lexer"
+
+namespace jinja {
+
+static void string_lstrip(std::string & s, const char * chars) {
+ size_t start = s.find_first_not_of(chars);
+ if (start == std::string::npos) {
+ s.clear();
+ } else {
+ s.erase(0, start);
+ }
+}
+
+static void string_rstrip(std::string & s, const char * chars) {
+ size_t end = s.find_last_not_of(chars);
+ if (end == std::string::npos) {
+ s.clear();
+ } else {
+ s.erase(end + 1);
+ }
+}
+
+lexer_result lexer::tokenize(const std::string & source) {
+ std::vector<token> tokens;
+
+ // NOTE: do NOT transform the source string (i.e. preprocessing), as we need to keep
+ // the original character positions for error reporting etc.
+ std::string src = source;
+
+ if (source.empty()) {
+ return {tokens, src};
+ }
+
+ // Normalize \r\n or \r to \n
+ for (std::string::size_type pos = 0; (pos = src.find("\r\n", pos)) != std::string::npos; ) {
+ src.erase(pos, 1);
+ ++pos;
+ }
+ for (std::string::size_type pos = 0; (pos = src.find("\r", pos)) != std::string::npos; ) {
+ src.replace(pos, 1, 1, '\n');
+ ++pos;
+ }
+
+ // In the default configuration:
+ // - a single trailing newline is stripped if present
+ // - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
+ if (source.back() == '\n') {
+ src.pop_back();
+ }
+
+ size_t pos = 0;
+ size_t start_pos = 0;
+ size_t curly_bracket_depth = 0;
+
+ using pred = std::function<bool(char)>;
+ auto consume_while = [&](const pred & predicate) -> std::string {
+ std::string str;
+ while (predicate(src[pos])) {
+ // check for escape char
+ if (src[pos] == '\\') {
+ // consume backslash
+ ++pos;
+ // check for end of input
+ if (pos >= src.size()) {
+ throw lexer_exception("unexpected end of input after escape character", source, pos);
+ }
+ // add escaped char
+ char escaped_char = src[pos++];
+ if (escape_chars.find(escaped_char) == escape_chars.end()) {
+ throw lexer_exception(std::string("unknown escape character \\") + escaped_char, source, pos);
+ }
+ char unescaped_char = escape_chars.at(escaped_char);
+ str += unescaped_char;
+ continue;
+ }
+
+ str += src[pos++];
+ if (pos > src.size()) {
+ throw lexer_exception("unexpected end of input during consume_while", source, pos);
+ }
+ }
+ return str;
+ };
+
+ auto consume_numeric = [&]() -> std::string {
+ std::string num = consume_while(is_integer);
+ if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
+ ++pos; // Consume '.'
+ std::string frac = consume_while(is_integer);
+ num += "." + frac;
+ }
+ return num;
+ };
+
+ auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
+ if (pos + n >= src.size()) return false;
+ for (char c : chars) {
+ if (src[pos + n] == c) return true;
+ }
+ return false;
+ };
+
+ // note: default config for chat template: lstrip_blocks = true, trim_blocks = true
+
+ // text\n[space]{block} --> text\n{block}
+ bool opt_lstrip_blocks = true;
+
+ // {block}\n[space]text --> {block}[space]text
+ bool opt_trim_blocks = true;
+
+ // options set dynamically based on current/last block
+ bool is_lstrip_block = false; // example: {%-
+ bool is_rstrip_block = false; // example: -%}
+
+ while (pos < src.size()) {
+ start_pos = pos;
+ // JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
+
+ // First, consume all text that is outside of a Jinja statement or expression
+ token::type last_token_type = tokens.empty()
+ ? token::close_statement // initial state
+ : tokens.back().t;
+ if (last_token_type == token::close_statement ||
+ last_token_type == token::close_expression ||
+ last_token_type == token::comment) {
+
+ bool last_block_can_rm_newline = false;
+ is_rstrip_block = false;
+ if (pos > 3) {
+ char c0 = src[pos - 3];
+ char c1 = src[pos - 2];
+ char c2 = src[pos - 1];
+ // strip if: -[%}#]}text
+ is_rstrip_block = c0 == '-'
+ && (c1 == '%' || c1 == '}' || c1 == '#')
+ && c2 == '}';
+ // match behavior of hf.js: exclude {{ and }} cases, regex: ([#%-]})
+ last_block_can_rm_newline = (c1 == '#' || c1 == '%' || c1 == '-') && c2 == '}';
+ }
+
+ size_t start = pos;
+ size_t end = start;
+ while (pos < src.size() &&
+ // Keep going until we hit the next Jinja statement or expression
+ !(
+ src[pos] == '{' &&
+ next_pos_is( {'%', '{', '#'} )
+ )) {
+ end = ++pos;
+ }
+
+ // equivalent to hf.js code: template.replace(/^[ \t]*({[#%-])/gm, "$1");
+ if (opt_lstrip_blocks && src[pos] == '{' && next_pos_is({'%', '#', '-'})) {
+ size_t current = end;
+ while (current > start) {
+ char c = src[current - 1];
+ if (current == 1) {
+ end = 0; // Trim from the start of the string
+ break;
+ }
+ if (c == '\n') {
+ end = current; // Trim from the start of the line
+ break;
+ }
+ if (!std::isspace(static_cast<unsigned char>(c))) {
+ break; // Found non-whitespace before newline, keep
+ }
+ --current;
+ }
+ }
+
+ std::string text = src.substr(start, end - start);
+
+ // equivalent to hf.js code: template.replace(/([#%-]})\n/g, "$1");
+ if (opt_trim_blocks && last_block_can_rm_newline) {
+ if (!text.empty() && text.front() == '\n') {
+ text.erase(text.begin());
+ }
+ }
+
+ if (is_rstrip_block) {
+ // example: {last_block}[space]text
+ // doing lstrip on text, effectively rstrip the LAST block
+ // JJ_DEBUG("RSTRIP block detected, current text: '%s'", text.c_str());
+ string_lstrip(text, " \t\r\n");
+ }
+
+ is_lstrip_block = src[pos] == '{' && next_pos_is({'{', '%', '#'}) && next_pos_is({'-'}, 2);
+ if (is_lstrip_block) {
+ // example: text[space]{current_block}
+ // doing rstrip on text, effectively lstrip the CURRENT block
+ // JJ_DEBUG("LSTRIP block detected, current text: '%s'", text.c_str());
+ string_rstrip(text, " \t\r\n");
+ }
+
+ if (!text.empty()) {
+ // JJ_DEBUG("consumed text: '%s'", text.c_str());
+ tokens.push_back({token::text, text, start_pos});
+ continue;
+ }
+ }
+
+ // Possibly consume a comment
+ // TODO: handle lstrip/rstrip for comments? (not important for now)
+ if (src[pos] == '{' && next_pos_is( {'#'} )) {
+ start_pos = pos;
+ pos += 2; // Skip the opening {#
+ std::string comment;
+ while (!(src[pos] == '#' && next_pos_is( {'}'} ))) {
+ if (pos + 2 >= src.size()) {
+ throw lexer_exception("missing end of comment tag", source, pos);
+ }
+ comment += src[pos++];
+ }
+ JJ_DEBUG("consumed comment: '%s'", comment.c_str());
+ tokens.push_back({token::comment, comment, start_pos});
+ pos += 2; // Skip the closing #}
+ continue;
+ }
+
+ if (src[pos] == '-' && (
+ last_token_type == token::open_expression ||
+ last_token_type == token::open_statement)
+ ) {
+ JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
+ pos++; // consume '-' in {%- or {{-
+ if (pos >= src.size()) break;
+ }
+
+ // Consume (and ignore) all whitespace inside Jinja statements or expressions
+ consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)); });
+
+ if (pos >= src.size()) break;
+
+ char ch = src[pos];
+
+ bool is_closing_block = ch == '-' && next_pos_is( {'%', '}'} );
+
+ // Check for unary operators
+ if (!is_closing_block && (ch == '-' || ch == '+')) {
+ start_pos = pos;
+ token::type last_token_type = tokens.empty() ? token::eof : tokens.back().t;
+ if (last_token_type == token::text || last_token_type == token::eof) {
+ throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
+ }
+ switch (last_token_type) {
+ case token::identifier:
+ case token::numeric_literal:
+ case token::string_literal:
+ case token::close_paren:
+ case token::close_square_bracket:
+ // Part of a binary operator
+ // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
+ // Continue parsing normally
+ break;
+ default: {
+ // Is part of a unary operator
+ // (-1), [-1], (1 + -1), not -1, -apple
+ ++pos; // Consume the operator
+
+ // Check for numbers following the unary operator
+ std::string num = consume_numeric();
+ std::string value = std::string(1, ch) + num;
+ token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
+ // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
+ tokens.push_back({t, value, start_pos});
+ continue;
+ }
+ }
+ }
+
+ // Try to match one of the tokens in the mapping table
+ bool matched = false;
+ for (const auto & [seq, typ] : ordered_mapping_table) {
+ start_pos = pos;
+ // Inside an object literal, don't treat "}}" as expression-end
+ if (seq == "}}" && curly_bracket_depth > 0) {
+ continue;
+ }
+ if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) {
+ tokens.push_back({typ, seq, start_pos});
+ if (typ == token::open_expression) {
+ curly_bracket_depth = 0;
+ } else if (typ == token::open_curly_bracket) {
+ ++curly_bracket_depth;
+ } else if (typ == token::close_curly_bracket) {
+ --curly_bracket_depth;
+ }
+
+ pos += seq.size();
+ matched = true;
+ break; // continue main loop
+ }
+ }
+ if (matched) continue; // continue main loop
+
+ // Strings
+ if (ch == '\'' || ch == '"') {
+ start_pos = pos;
+ ++pos; // Skip opening quote
+ std::string str = consume_while([ch](char c) { return c != ch; });
+ // JJ_DEBUG("consumed string literal: '%s'", str.c_str());
+ tokens.push_back({token::string_literal, str, start_pos});
+ ++pos; // Skip closing quote
+ continue;
+ }
+
+ // Numbers
+ if (is_integer(ch)) {
+ start_pos = pos;
+ std::string num = consume_numeric();
+ // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
+ tokens.push_back({token::numeric_literal, num, start_pos});
+ continue;
+ }
+
+ // Identifiers
+ if (is_word(ch)) {
+ start_pos = pos;
+ std::string word = consume_while(is_word);
+ // JJ_DEBUG("consumed identifier: '%s'", word.c_str());
+ tokens.push_back({token::identifier, word, start_pos});
+ continue;
+ }
+
+ throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
+ }
+
+ return {std::move(tokens), src};
+}
+
+} // namespace jinja