#include "tree_sitter/parser.h"

#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
  7
  8#define MAX(a, b) ((a) > (b) ? (a) : (b))
  9
 10#define VEC_RESIZE(vec, _cap)                                                  \
 11    void *tmp = realloc((vec).data, (_cap) * sizeof((vec).data[0]));           \
 12    assert(tmp != NULL);                                                       \
 13    (vec).data = tmp;                                                          \
 14    (vec).cap = (_cap);
 15
 16#define VEC_GROW(vec, _cap)                                                    \
 17    if ((vec).cap < (_cap)) {                                                  \
 18        VEC_RESIZE((vec), (_cap));                                             \
 19    }
 20
 21#define VEC_PUSH(vec, el)                                                      \
 22    if ((vec).cap == (vec).len) {                                              \
 23        VEC_RESIZE((vec), MAX(16, (vec).len * 2));                             \
 24    }                                                                          \
 25    (vec).data[(vec).len++] = (el);
 26
 27#define VEC_POP(vec) (vec).len--;
 28
 29#define VEC_NEW                                                                \
 30    { .len = 0, .cap = 0, .data = NULL }
 31
 32#define VEC_BACK(vec) ((vec).data[(vec).len - 1])
 33
 34#define VEC_FREE(vec)                                                          \
 35    {                                                                          \
 36        if ((vec).data != NULL)                                                \
 37            free((vec).data);                                                  \
 38    }
 39
 40#define VEC_CLEAR(vec) (vec).len = 0;
 41
 42enum TokenType {
 43    NEWLINE,
 44    INDENT,
 45    DEDENT,
 46    STRING_START,
 47    STRING_CONTENT,
 48    ESCAPE_INTERPOLATION,
 49    STRING_END,
 50    COMMENT,
 51    CLOSE_PAREN,
 52    CLOSE_BRACKET,
 53    CLOSE_BRACE,
 54    EXCEPT,
 55};
 56
 57typedef enum {
 58    SingleQuote = 1 << 0,
 59    DoubleQuote = 1 << 1,
 60    BackQuote = 1 << 2,
 61    Raw = 1 << 3,
 62    Format = 1 << 4,
 63    Triple = 1 << 5,
 64    Bytes = 1 << 6,
 65} Flags;
 66
 67typedef struct {
 68    char flags;
 69} Delimiter;
 70
 71static inline Delimiter new_delimiter() { return (Delimiter){0}; }
 72
 73static inline bool is_format(Delimiter *delimiter) {
 74    return delimiter->flags & Format;
 75}
 76
 77static inline bool is_raw(Delimiter *delimiter) {
 78    return delimiter->flags & Raw;
 79}
 80
 81static inline bool is_triple(Delimiter *delimiter) {
 82    return delimiter->flags & Triple;
 83}
 84
 85static inline bool is_bytes(Delimiter *delimiter) {
 86    return delimiter->flags & Bytes;
 87}
 88
 89static inline int32_t end_character(Delimiter *delimiter) {
 90    if (delimiter->flags & SingleQuote) {
 91        return '\'';
 92    }
 93    if (delimiter->flags & DoubleQuote) {
 94        return '"';
 95    }
 96    if (delimiter->flags & BackQuote) {
 97        return '`';
 98    }
 99    return 0;
100}
101
102static inline void set_format(Delimiter *delimiter) {
103    delimiter->flags |= Format;
104}
105
106static inline void set_raw(Delimiter *delimiter) { delimiter->flags |= Raw; }
107
108static inline void set_triple(Delimiter *delimiter) {
109    delimiter->flags |= Triple;
110}
111
112static inline void set_bytes(Delimiter *delimiter) {
113    delimiter->flags |= Bytes;
114}
115
116static inline void set_end_character(Delimiter *delimiter, int32_t character) {
117    switch (character) {
118        case '\'':
119            delimiter->flags |= SingleQuote;
120            break;
121        case '"':
122            delimiter->flags |= DoubleQuote;
123            break;
124        case '`':
125            delimiter->flags |= BackQuote;
126            break;
127        default:
128            assert(false);
129    }
130}
131
/* Stack of indentation lengths of the currently open blocks (the scanner
 * counts a space as 1 and a tab as 8). */
typedef struct {
    uint32_t len;
    uint32_t cap;
    uint16_t *data;
} indent_vec;

/*
 * Create an empty indent stack with storage for one entry. If the allocation
 * fails, the capacity is reported as 0 (data NULL). The first push then
 * reallocates instead of writing through a NULL pointer.
 */
static indent_vec indent_vec_new(void) {
    indent_vec vec = {.len = 0, .cap = 1, .data = calloc(1, sizeof(uint16_t))};
    if (vec.data == NULL) {
        vec.cap = 0;
    }
    return vec;
}
144
145typedef struct {
146    uint32_t len;
147    uint32_t cap;
148    Delimiter *data;
149} delimiter_vec;
150
151static delimiter_vec delimiter_vec_new() {
152    delimiter_vec vec = VEC_NEW;
153    vec.data = calloc(1, sizeof(Delimiter));
154    vec.cap = 1;
155    return vec;
156}
157
158typedef struct {
159    indent_vec indents;
160    delimiter_vec delimiters;
161    bool inside_f_string;
162} Scanner;
163
164static inline void advance(TSLexer *lexer) { lexer->advance(lexer, false); }
165
166static inline void skip(TSLexer *lexer) { lexer->advance(lexer, true); }
167
168bool tree_sitter_python_external_scanner_scan(void *payload, TSLexer *lexer,
169                                              const bool *valid_symbols) {
170    Scanner *scanner = (Scanner *)payload;
171
172    bool error_recovery_mode =
173        valid_symbols[STRING_CONTENT] && valid_symbols[INDENT];
174    bool within_brackets = valid_symbols[CLOSE_BRACE] ||
175                           valid_symbols[CLOSE_PAREN] ||
176                           valid_symbols[CLOSE_BRACKET];
177
178    bool advanced_once = false;
179    if (valid_symbols[ESCAPE_INTERPOLATION] && scanner->delimiters.len > 0 &&
180        (lexer->lookahead == '{' || lexer->lookahead == '}') &&
181        !error_recovery_mode) {
182        Delimiter delimiter = VEC_BACK(scanner->delimiters);
183        if (is_format(&delimiter)) {
184            lexer->mark_end(lexer);
185            bool is_left_brace = lexer->lookahead == '{';
186            advance(lexer);
187            advanced_once = true;
188            if ((lexer->lookahead == '{' && is_left_brace) ||
189                (lexer->lookahead == '}' && !is_left_brace)) {
190                advance(lexer);
191                lexer->mark_end(lexer);
192                lexer->result_symbol = ESCAPE_INTERPOLATION;
193                return true;
194            }
195            return false;
196        }
197    }
198
199    if (valid_symbols[STRING_CONTENT] && scanner->delimiters.len > 0 &&
200        !error_recovery_mode) {
201        Delimiter delimiter = VEC_BACK(scanner->delimiters);
202        int32_t end_char = end_character(&delimiter);
203        bool has_content = advanced_once;
204        while (lexer->lookahead) {
205            if ((advanced_once || lexer->lookahead == '{' ||
206                 lexer->lookahead == '}') &&
207                is_format(&delimiter)) {
208                lexer->mark_end(lexer);
209                lexer->result_symbol = STRING_CONTENT;
210                return has_content;
211            }
212            if (lexer->lookahead == '\\') {
213                if (is_raw(&delimiter)) {
214                    // Step over the backslash.
215                    advance(lexer);
216                    // Step over any escaped quotes.
217                    if (lexer->lookahead == end_character(&delimiter) ||
218                        lexer->lookahead == '\\') {
219                        advance(lexer);
220                    }
221                    // Step over newlines
222                    if (lexer->lookahead == '\r') {
223                        advance(lexer);
224                        if (lexer->lookahead == '\n') {
225                            advance(lexer);
226                        }
227                    } else if (lexer->lookahead == '\n') {
228                        advance(lexer);
229                    }
230                    continue;
231                }
232                if (is_bytes(&delimiter)) {
233                    lexer->mark_end(lexer);
234                    advance(lexer);
235                    if (lexer->lookahead == 'N' || lexer->lookahead == 'u' ||
236                        lexer->lookahead == 'U') {
237                        // In bytes string, \N{...}, \uXXXX and \UXXXXXXXX are
238                        // not escape sequences
239                        // https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
240                        advance(lexer);
241                    } else {
242                        lexer->result_symbol = STRING_CONTENT;
243                        return has_content;
244                    }
245                } else {
246                    lexer->mark_end(lexer);
247                    lexer->result_symbol = STRING_CONTENT;
248                    return has_content;
249                }
250            } else if (lexer->lookahead == end_char) {
251                if (is_triple(&delimiter)) {
252                    lexer->mark_end(lexer);
253                    advance(lexer);
254                    if (lexer->lookahead == end_char) {
255                        advance(lexer);
256                        if (lexer->lookahead == end_char) {
257                            if (has_content) {
258                                lexer->result_symbol = STRING_CONTENT;
259                            } else {
260                                advance(lexer);
261                                lexer->mark_end(lexer);
262                                VEC_POP(scanner->delimiters);
263                                lexer->result_symbol = STRING_END;
264                                scanner->inside_f_string = false;
265                            }
266                            return true;
267                        }
268                        lexer->mark_end(lexer);
269                        lexer->result_symbol = STRING_CONTENT;
270                        return true;
271                    }
272                    lexer->mark_end(lexer);
273                    lexer->result_symbol = STRING_CONTENT;
274                    return true;
275                }
276                if (has_content) {
277                    lexer->result_symbol = STRING_CONTENT;
278                } else {
279                    advance(lexer);
280                    VEC_POP(scanner->delimiters);
281                    lexer->result_symbol = STRING_END;
282                    scanner->inside_f_string = false;
283                }
284                lexer->mark_end(lexer);
285                return true;
286
287            } else if (lexer->lookahead == '\n' && has_content &&
288                       !is_triple(&delimiter)) {
289                return false;
290            }
291            advance(lexer);
292            has_content = true;
293        }
294    }
295
296    lexer->mark_end(lexer);
297
298    bool found_end_of_line = false;
299    uint32_t indent_length = 0;
300    int32_t first_comment_indent_length = -1;
301    for (;;) {
302        if (lexer->lookahead == '\n') {
303            found_end_of_line = true;
304            indent_length = 0;
305            skip(lexer);
306        } else if (lexer->lookahead == ' ') {
307            indent_length++;
308            skip(lexer);
309        } else if (lexer->lookahead == '\r' || lexer->lookahead == '\f') {
310            indent_length = 0;
311            skip(lexer);
312        } else if (lexer->lookahead == '\t') {
313            indent_length += 8;
314            skip(lexer);
315        } else if (lexer->lookahead == '#' &&
316                   (valid_symbols[INDENT] || valid_symbols[DEDENT] ||
317                    valid_symbols[NEWLINE] || valid_symbols[EXCEPT])) {
318            // If we haven't found an EOL yet,
319            // then this is a comment after an expression:
320            //   foo = bar # comment
321            // Just return, since we don't want to generate an indent/dedent
322            // token.
323            if (!found_end_of_line) {
324                return false;
325            }
326            if (first_comment_indent_length == -1) {
327                first_comment_indent_length = (int32_t)indent_length;
328            }
329            while (lexer->lookahead && lexer->lookahead != '\n') {
330                skip(lexer);
331            }
332            skip(lexer);
333            indent_length = 0;
334        } else if (lexer->lookahead == '\\') {
335            skip(lexer);
336            if (lexer->lookahead == '\r') {
337                skip(lexer);
338            }
339            if (lexer->lookahead == '\n' || lexer->eof(lexer)) {
340                skip(lexer);
341            } else {
342                return false;
343            }
344        } else if (lexer->eof(lexer)) {
345            indent_length = 0;
346            found_end_of_line = true;
347            break;
348        } else {
349            break;
350        }
351    }
352
353    if (found_end_of_line) {
354        if (scanner->indents.len > 0) {
355            uint16_t current_indent_length = VEC_BACK(scanner->indents);
356
357            if (valid_symbols[INDENT] &&
358                indent_length > current_indent_length) {
359                VEC_PUSH(scanner->indents, indent_length);
360                lexer->result_symbol = INDENT;
361                return true;
362            }
363
364            bool next_tok_is_string_start = lexer->lookahead == '\"' ||
365                                            lexer->lookahead == '\'' ||
366                                            lexer->lookahead == '`';
367
368            if ((valid_symbols[DEDENT] ||
369                 (!valid_symbols[NEWLINE] &&
370                  !(valid_symbols[STRING_START] && next_tok_is_string_start) &&
371                  !within_brackets)) &&
372                indent_length < current_indent_length &&
373                !scanner->inside_f_string &&
374
375                // Wait to create a dedent token until we've consumed any
376                // comments
377                // whose indentation matches the current block.
378                first_comment_indent_length < (int32_t)current_indent_length) {
379                VEC_POP(scanner->indents);
380                lexer->result_symbol = DEDENT;
381                return true;
382            }
383        }
384
385        if (valid_symbols[NEWLINE] && !error_recovery_mode) {
386            lexer->result_symbol = NEWLINE;
387            return true;
388        }
389    }
390
391    if (first_comment_indent_length == -1 && valid_symbols[STRING_START]) {
392        Delimiter delimiter = new_delimiter();
393
394        bool has_flags = false;
395        while (lexer->lookahead) {
396            if (lexer->lookahead == 'f' || lexer->lookahead == 'F') {
397                set_format(&delimiter);
398            } else if (lexer->lookahead == 'r' || lexer->lookahead == 'R') {
399                set_raw(&delimiter);
400            } else if (lexer->lookahead == 'b' || lexer->lookahead == 'B') {
401                set_bytes(&delimiter);
402            } else if (lexer->lookahead != 'u' && lexer->lookahead != 'U') {
403                break;
404            }
405            has_flags = true;
406            advance(lexer);
407        }
408
409        if (lexer->lookahead == '`') {
410            set_end_character(&delimiter, '`');
411            advance(lexer);
412            lexer->mark_end(lexer);
413        } else if (lexer->lookahead == '\'') {
414            set_end_character(&delimiter, '\'');
415            advance(lexer);
416            lexer->mark_end(lexer);
417            if (lexer->lookahead == '\'') {
418                advance(lexer);
419                if (lexer->lookahead == '\'') {
420                    advance(lexer);
421                    lexer->mark_end(lexer);
422                    set_triple(&delimiter);
423                }
424            }
425        } else if (lexer->lookahead == '"') {
426            set_end_character(&delimiter, '"');
427            advance(lexer);
428            lexer->mark_end(lexer);
429            if (lexer->lookahead == '"') {
430                advance(lexer);
431                if (lexer->lookahead == '"') {
432                    advance(lexer);
433                    lexer->mark_end(lexer);
434                    set_triple(&delimiter);
435                }
436            }
437        }
438
439        if (end_character(&delimiter)) {
440            VEC_PUSH(scanner->delimiters, delimiter);
441            lexer->result_symbol = STRING_START;
442            scanner->inside_f_string = is_format(&delimiter);
443            return true;
444        }
445        if (has_flags) {
446            return false;
447        }
448    }
449
450    return false;
451}
452
453unsigned tree_sitter_python_external_scanner_serialize(void *payload,
454                                                       char *buffer) {
455    Scanner *scanner = (Scanner *)payload;
456
457    size_t size = 0;
458
459    buffer[size++] = (char)scanner->inside_f_string;
460
461    size_t delimiter_count = scanner->delimiters.len;
462    if (delimiter_count > UINT8_MAX) {
463        delimiter_count = UINT8_MAX;
464    }
465    buffer[size++] = (char)delimiter_count;
466
467    if (delimiter_count > 0) {
468        memcpy(&buffer[size], scanner->delimiters.data, delimiter_count);
469    }
470    size += delimiter_count;
471
472    int iter = 1;
473    for (; iter < scanner->indents.len &&
474           size < TREE_SITTER_SERIALIZATION_BUFFER_SIZE;
475         ++iter) {
476        buffer[size++] = (char)scanner->indents.data[iter];
477    }
478
479    return size;
480}
481
482void tree_sitter_python_external_scanner_deserialize(void *payload,
483                                                     const char *buffer,
484                                                     unsigned length) {
485    Scanner *scanner = (Scanner *)payload;
486
487    VEC_CLEAR(scanner->delimiters);
488    VEC_CLEAR(scanner->indents);
489    VEC_PUSH(scanner->indents, 0);
490
491    if (length > 0) {
492        size_t size = 0;
493
494        scanner->inside_f_string = (bool)buffer[size++];
495
496        size_t delimiter_count = (uint8_t)buffer[size++];
497        if (delimiter_count > 0) {
498            VEC_GROW(scanner->delimiters, delimiter_count);
499            scanner->delimiters.len = delimiter_count;
500            memcpy(scanner->delimiters.data, &buffer[size], delimiter_count);
501            size += delimiter_count;
502        }
503
504        for (; size < length; size++) {
505            VEC_PUSH(scanner->indents, (unsigned char)buffer[size]);
506        }
507    }
508}
509
510void *tree_sitter_python_external_scanner_create() {
511#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)
512    _Static_assert(sizeof(Delimiter) == sizeof(char), "");
513#else
514    assert(sizeof(Delimiter) == sizeof(char));
515#endif
516    Scanner *scanner = calloc(1, sizeof(Scanner));
517    scanner->indents = indent_vec_new();
518    scanner->delimiters = delimiter_vec_new();
519    tree_sitter_python_external_scanner_deserialize(scanner, NULL, 0);
520    return scanner;
521}
522
523void tree_sitter_python_external_scanner_destroy(void *payload) {
524    Scanner *scanner = (Scanner *)payload;
525    VEC_FREE(scanner->indents);
526    VEC_FREE(scanner->delimiters);
527    free(scanner);
528}