1#include "lexer.h"
2#include "runtime.h"
3
4#include <cctype>
5#include <functional>
6#include <map>
7#include <string>
8#include <vector>
9
10#define FILENAME "jinja-lexer"
11
12namespace jinja {
13
// Remove, in place, every leading character of `s` that occurs in `chars`.
// A string made up solely of such characters ends up empty.
static void string_lstrip(std::string & s, const char * chars) {
    // find_first_not_of yields npos when nothing should be kept; erase(0, npos)
    // then removes the whole string, so no separate branch is needed.
    s.erase(0, s.find_first_not_of(chars));
}
22
// Remove, in place, every trailing character of `s` that occurs in `chars`.
// A string made up solely of such characters ends up empty.
static void string_rstrip(std::string & s, const char * chars) {
    const std::string::size_type last_kept = s.find_last_not_of(chars);
    // Keep everything up to and including the last character not in `chars`.
    const std::string::size_type new_length =
        (last_kept == std::string::npos) ? 0 : last_kept + 1;
    s.resize(new_length);
}
31
32lexer_result lexer::tokenize(const std::string & source) {
33 std::vector<token> tokens;
34
35 // NOTE: do NOT transform the source string (i.e. preprocessing), as we need to keep
36 // the original character positions for error reporting etc.
37 std::string src = source;
38
39 if (source.empty()) {
40 return {tokens, src};
41 }
42
43 // Normalize \r\n or \r to \n
44 for (std::string::size_type pos = 0; (pos = src.find("\r\n", pos)) != std::string::npos; ) {
45 src.erase(pos, 1);
46 ++pos;
47 }
48 for (std::string::size_type pos = 0; (pos = src.find("\r", pos)) != std::string::npos; ) {
49 src.replace(pos, 1, 1, '\n');
50 ++pos;
51 }
52
53 // In the default configuration:
54 // - a single trailing newline is stripped if present
55 // - other whitespace (spaces, tabs, newlines etc.) is returned unchanged
56 if (source.back() == '\n') {
57 src.pop_back();
58 }
59
60 size_t pos = 0;
61 size_t start_pos = 0;
62 size_t curly_bracket_depth = 0;
63
64 using pred = std::function<bool(char)>;
65 auto consume_while = [&](const pred & predicate) -> std::string {
66 std::string str;
67 while (predicate(src[pos])) {
68 // check for escape char
69 if (src[pos] == '\\') {
70 // consume backslash
71 ++pos;
72 // check for end of input
73 if (pos >= src.size()) {
74 throw lexer_exception("unexpected end of input after escape character", source, pos);
75 }
76 // add escaped char
77 char escaped_char = src[pos++];
78 if (escape_chars.find(escaped_char) == escape_chars.end()) {
79 throw lexer_exception(std::string("unknown escape character \\") + escaped_char, source, pos);
80 }
81 char unescaped_char = escape_chars.at(escaped_char);
82 str += unescaped_char;
83 continue;
84 }
85
86 str += src[pos++];
87 if (pos > src.size()) {
88 throw lexer_exception("unexpected end of input during consume_while", source, pos);
89 }
90 }
91 return str;
92 };
93
94 auto consume_numeric = [&]() -> std::string {
95 std::string num = consume_while(is_integer);
96 if (pos < src.size() && src[pos] == '.' && pos + 1 < src.size() && is_integer(src[pos + 1])) {
97 ++pos; // Consume '.'
98 std::string frac = consume_while(is_integer);
99 num += "." + frac;
100 }
101 return num;
102 };
103
104 auto next_pos_is = [&](std::initializer_list<char> chars, size_t n = 1) -> bool {
105 if (pos + n >= src.size()) return false;
106 for (char c : chars) {
107 if (src[pos + n] == c) return true;
108 }
109 return false;
110 };
111
112 // note: default config for chat template: lstrip_blocks = true, trim_blocks = true
113
114 // text\n[space]{block} --> text\n{block}
115 bool opt_lstrip_blocks = true;
116
117 // {block}\n[space]text --> {block}[space]text
118 bool opt_trim_blocks = true;
119
120 // options set dynamically based on current/last block
121 bool is_lstrip_block = false; // example: {%-
122 bool is_rstrip_block = false; // example: -%}
123
124 while (pos < src.size()) {
125 start_pos = pos;
126 // JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
127
128 // First, consume all text that is outside of a Jinja statement or expression
129 token::type last_token_type = tokens.empty()
130 ? token::close_statement // initial state
131 : tokens.back().t;
132 if (last_token_type == token::close_statement ||
133 last_token_type == token::close_expression ||
134 last_token_type == token::comment) {
135
136 bool last_block_can_rm_newline = false;
137 is_rstrip_block = false;
138 if (pos > 3) {
139 char c0 = src[pos - 3];
140 char c1 = src[pos - 2];
141 char c2 = src[pos - 1];
142 // strip if: -[%}#]}text
143 is_rstrip_block = c0 == '-'
144 && (c1 == '%' || c1 == '}' || c1 == '#')
145 && c2 == '}';
146 // match behavior of hf.js: exclude {{ and }} cases, regex: ([#%-]})
147 last_block_can_rm_newline = (c1 == '#' || c1 == '%' || c1 == '-') && c2 == '}';
148 }
149
150 size_t start = pos;
151 size_t end = start;
152 while (pos < src.size() &&
153 // Keep going until we hit the next Jinja statement or expression
154 !(
155 src[pos] == '{' &&
156 next_pos_is( {'%', '{', '#'} )
157 )) {
158 end = ++pos;
159 }
160
161 // equivalent to hf.js code: template.replace(/^[ \t]*({[#%-])/gm, "$1");
162 if (opt_lstrip_blocks && src[pos] == '{' && next_pos_is({'%', '#', '-'})) {
163 size_t current = end;
164 while (current > start) {
165 char c = src[current - 1];
166 if (current == 1) {
167 end = 0; // Trim from the start of the string
168 break;
169 }
170 if (c == '\n') {
171 end = current; // Trim from the start of the line
172 break;
173 }
174 if (!std::isspace(static_cast<unsigned char>(c))) {
175 break; // Found non-whitespace before newline, keep
176 }
177 --current;
178 }
179 }
180
181 std::string text = src.substr(start, end - start);
182
183 // equivalent to hf.js code: template.replace(/([#%-]})\n/g, "$1");
184 if (opt_trim_blocks && last_block_can_rm_newline) {
185 if (!text.empty() && text.front() == '\n') {
186 text.erase(text.begin());
187 }
188 }
189
190 if (is_rstrip_block) {
191 // example: {last_block}[space]text
192 // doing lstrip on text, effectively rstrip the LAST block
193 // JJ_DEBUG("RSTRIP block detected, current text: '%s'", text.c_str());
194 string_lstrip(text, " \t\r\n");
195 }
196
197 is_lstrip_block = src[pos] == '{' && next_pos_is({'{', '%', '#'}) && next_pos_is({'-'}, 2);
198 if (is_lstrip_block) {
199 // example: text[space]{current_block}
200 // doing rstrip on text, effectively lstrip the CURRENT block
201 // JJ_DEBUG("LSTRIP block detected, current text: '%s'", text.c_str());
202 string_rstrip(text, " \t\r\n");
203 }
204
205 if (!text.empty()) {
206 // JJ_DEBUG("consumed text: '%s'", text.c_str());
207 tokens.push_back({token::text, text, start_pos});
208 continue;
209 }
210 }
211
212 // Possibly consume a comment
213 // TODO: handle lstrip/rstrip for comments? (not important for now)
214 if (src[pos] == '{' && next_pos_is( {'#'} )) {
215 start_pos = pos;
216 pos += 2; // Skip the opening {#
217 std::string comment;
218 while (!(src[pos] == '#' && next_pos_is( {'}'} ))) {
219 if (pos + 2 >= src.size()) {
220 throw lexer_exception("missing end of comment tag", source, pos);
221 }
222 comment += src[pos++];
223 }
224 JJ_DEBUG("consumed comment: '%s'", comment.c_str());
225 tokens.push_back({token::comment, comment, start_pos});
226 pos += 2; // Skip the closing #}
227 continue;
228 }
229
230 if (src[pos] == '-' && (
231 last_token_type == token::open_expression ||
232 last_token_type == token::open_statement)
233 ) {
234 JJ_DEBUG("lexer main loop at pos %zu: '%s...'", pos, src.substr(pos, 10).c_str());
235 pos++; // consume '-' in {%- or {{-
236 if (pos >= src.size()) break;
237 }
238
239 // Consume (and ignore) all whitespace inside Jinja statements or expressions
240 consume_while([](char c) { return std::isspace(static_cast<unsigned char>(c)); });
241
242 if (pos >= src.size()) break;
243
244 char ch = src[pos];
245
246 bool is_closing_block = ch == '-' && next_pos_is( {'%', '}'} );
247
248 // Check for unary operators
249 if (!is_closing_block && (ch == '-' || ch == '+')) {
250 start_pos = pos;
251 token::type last_token_type = tokens.empty() ? token::eof : tokens.back().t;
252 if (last_token_type == token::text || last_token_type == token::eof) {
253 throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
254 }
255 switch (last_token_type) {
256 case token::identifier:
257 case token::numeric_literal:
258 case token::string_literal:
259 case token::close_paren:
260 case token::close_square_bracket:
261 // Part of a binary operator
262 // a - 1, 1 - 1, true - 1, "apple" - 1, (1) - 1, a[1] - 1
263 // Continue parsing normally
264 break;
265 default: {
266 // Is part of a unary operator
267 // (-1), [-1], (1 + -1), not -1, -apple
268 ++pos; // Consume the operator
269
270 // Check for numbers following the unary operator
271 std::string num = consume_numeric();
272 std::string value = std::string(1, ch) + num;
273 token::type t = num.empty() ? token::unary_operator : token::numeric_literal;
274 // JJ_DEBUG("consumed unary operator or numeric literal: '%s'", value.c_str());
275 tokens.push_back({t, value, start_pos});
276 continue;
277 }
278 }
279 }
280
281 // Try to match one of the tokens in the mapping table
282 bool matched = false;
283 for (const auto & [seq, typ] : ordered_mapping_table) {
284 start_pos = pos;
285 // Inside an object literal, don't treat "}}" as expression-end
286 if (seq == "}}" && curly_bracket_depth > 0) {
287 continue;
288 }
289 if (pos + seq.size() <= src.size() && src.substr(pos, seq.size()) == seq) {
290 tokens.push_back({typ, seq, start_pos});
291 if (typ == token::open_expression) {
292 curly_bracket_depth = 0;
293 } else if (typ == token::open_curly_bracket) {
294 ++curly_bracket_depth;
295 } else if (typ == token::close_curly_bracket) {
296 --curly_bracket_depth;
297 }
298
299 pos += seq.size();
300 matched = true;
301 break; // continue main loop
302 }
303 }
304 if (matched) continue; // continue main loop
305
306 // Strings
307 if (ch == '\'' || ch == '"') {
308 start_pos = pos;
309 ++pos; // Skip opening quote
310 std::string str = consume_while([ch](char c) { return c != ch; });
311 // JJ_DEBUG("consumed string literal: '%s'", str.c_str());
312 tokens.push_back({token::string_literal, str, start_pos});
313 ++pos; // Skip closing quote
314 continue;
315 }
316
317 // Numbers
318 if (is_integer(ch)) {
319 start_pos = pos;
320 std::string num = consume_numeric();
321 // JJ_DEBUG("consumed numeric literal: '%s'", num.c_str());
322 tokens.push_back({token::numeric_literal, num, start_pos});
323 continue;
324 }
325
326 // Identifiers
327 if (is_word(ch)) {
328 start_pos = pos;
329 std::string word = consume_while(is_word);
330 // JJ_DEBUG("consumed identifier: '%s'", word.c_str());
331 tokens.push_back({token::identifier, word, start_pos});
332 continue;
333 }
334
335 throw lexer_exception(std::string("unexpected character: ") + ch, source, pos);
336 }
337
338 return {std::move(tokens), src};
339}
340
341} // namespace jinja