#include "../array.h"
#include "parser.h"

#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
  8
  9enum TokenType {
 10    NEWLINE,
 11    INDENT,
 12    DEDENT,
 13    STRING_START,
 14    STRING_CONTENT,
 15    ESCAPE_INTERPOLATION,
 16    STRING_END,
 17    COMMENT,
 18    CLOSE_PAREN,
 19    CLOSE_BRACKET,
 20    CLOSE_BRACE,
 21    EXCEPT,
 22};
 23
 24typedef enum {
 25    SingleQuote = 1 << 0,
 26    DoubleQuote = 1 << 1,
 27    BackQuote = 1 << 2,
 28    Raw = 1 << 3,
 29    Format = 1 << 4,
 30    Triple = 1 << 5,
 31    Bytes = 1 << 6,
 32} Flags;
 33
 34typedef struct {
 35    char flags;
 36} Delimiter;
 37
 38static inline Delimiter new_delimiter() { return (Delimiter){0}; }
 39
 40static inline bool is_format(Delimiter *delimiter) { return delimiter->flags & Format; }
 41
 42static inline bool is_raw(Delimiter *delimiter) { return delimiter->flags & Raw; }
 43
 44static inline bool is_triple(Delimiter *delimiter) { return delimiter->flags & Triple; }
 45
 46static inline bool is_bytes(Delimiter *delimiter) { return delimiter->flags & Bytes; }
 47
 48static inline int32_t end_character(Delimiter *delimiter) {
 49    if (delimiter->flags & SingleQuote) {
 50        return '\'';
 51    }
 52    if (delimiter->flags & DoubleQuote) {
 53        return '"';
 54    }
 55    if (delimiter->flags & BackQuote) {
 56        return '`';
 57    }
 58    return 0;
 59}
 60
 61static inline void set_format(Delimiter *delimiter) { delimiter->flags |= Format; }
 62
 63static inline void set_raw(Delimiter *delimiter) { delimiter->flags |= Raw; }
 64
 65static inline void set_triple(Delimiter *delimiter) { delimiter->flags |= Triple; }
 66
 67static inline void set_bytes(Delimiter *delimiter) { delimiter->flags |= Bytes; }
 68
 69static inline void set_end_character(Delimiter *delimiter, int32_t character) {
 70    switch (character) {
 71        case '\'':
 72            delimiter->flags |= SingleQuote;
 73            break;
 74        case '"':
 75            delimiter->flags |= DoubleQuote;
 76            break;
 77        case '`':
 78            delimiter->flags |= BackQuote;
 79            break;
 80        default:
 81            assert(false);
 82    }
 83}
 84
 85typedef struct {
 86    Array(uint16_t) indents;
 87    Array(Delimiter) delimiters;
 88    bool inside_f_string;
 89} Scanner;
 90
 91static inline void advance(TSLexer *lexer) { lexer->advance(lexer, false); }
 92
 93static inline void skip(TSLexer *lexer) { lexer->advance(lexer, true); }
 94
 95bool tree_sitter_python_external_scanner_scan(void *payload, TSLexer *lexer, const bool *valid_symbols) {
 96    Scanner *scanner = (Scanner *)payload;
 97
 98    bool error_recovery_mode = valid_symbols[STRING_CONTENT] && valid_symbols[INDENT];
 99    bool within_brackets = valid_symbols[CLOSE_BRACE] || valid_symbols[CLOSE_PAREN] || valid_symbols[CLOSE_BRACKET];
100
101    bool advanced_once = false;
102    if (valid_symbols[ESCAPE_INTERPOLATION] && scanner->delimiters.size > 0 &&
103        (lexer->lookahead == '{' || lexer->lookahead == '}') && !error_recovery_mode) {
104        Delimiter *delimiter = array_back(&scanner->delimiters);
105        if (is_format(delimiter)) {
106            lexer->mark_end(lexer);
107            bool is_left_brace = lexer->lookahead == '{';
108            advance(lexer);
109            advanced_once = true;
110            if ((lexer->lookahead == '{' && is_left_brace) || (lexer->lookahead == '}' && !is_left_brace)) {
111                advance(lexer);
112                lexer->mark_end(lexer);
113                lexer->result_symbol = ESCAPE_INTERPOLATION;
114                return true;
115            }
116            return false;
117        }
118    }
119
120    if (valid_symbols[STRING_CONTENT] && scanner->delimiters.size > 0 && !error_recovery_mode) {
121        Delimiter *delimiter = array_back(&scanner->delimiters);
122        int32_t end_char = end_character(delimiter);
123        bool has_content = advanced_once;
124        while (lexer->lookahead) {
125            if ((advanced_once || lexer->lookahead == '{' || lexer->lookahead == '}') && is_format(delimiter)) {
126                lexer->mark_end(lexer);
127                lexer->result_symbol = STRING_CONTENT;
128                return has_content;
129            }
130            if (lexer->lookahead == '\\') {
131                if (is_raw(delimiter)) {
132                    // Step over the backslash.
133                    advance(lexer);
134                    // Step over any escaped quotes.
135                    if (lexer->lookahead == end_character(delimiter) || lexer->lookahead == '\\') {
136                        advance(lexer);
137                    }
138                    // Step over newlines
139                    if (lexer->lookahead == '\r') {
140                        advance(lexer);
141                        if (lexer->lookahead == '\n') {
142                            advance(lexer);
143                        }
144                    } else if (lexer->lookahead == '\n') {
145                        advance(lexer);
146                    }
147                    continue;
148                }
149                if (is_bytes(delimiter)) {
150                    lexer->mark_end(lexer);
151                    advance(lexer);
152                    if (lexer->lookahead == 'N' || lexer->lookahead == 'u' || lexer->lookahead == 'U') {
153                        // In bytes string, \N{...}, \uXXXX and \UXXXXXXXX are
154                        // not escape sequences
155                        // https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals
156                        advance(lexer);
157                    } else {
158                        lexer->result_symbol = STRING_CONTENT;
159                        return has_content;
160                    }
161                } else {
162                    lexer->mark_end(lexer);
163                    lexer->result_symbol = STRING_CONTENT;
164                    return has_content;
165                }
166            } else if (lexer->lookahead == end_char) {
167                if (is_triple(delimiter)) {
168                    lexer->mark_end(lexer);
169                    advance(lexer);
170                    if (lexer->lookahead == end_char) {
171                        advance(lexer);
172                        if (lexer->lookahead == end_char) {
173                            if (has_content) {
174                                lexer->result_symbol = STRING_CONTENT;
175                            } else {
176                                advance(lexer);
177                                lexer->mark_end(lexer);
178                                array_pop(&scanner->delimiters);
179                                lexer->result_symbol = STRING_END;
180                                scanner->inside_f_string = false;
181                            }
182                            return true;
183                        }
184                        lexer->mark_end(lexer);
185                        lexer->result_symbol = STRING_CONTENT;
186                        return true;
187                    }
188                    lexer->mark_end(lexer);
189                    lexer->result_symbol = STRING_CONTENT;
190                    return true;
191                }
192                if (has_content) {
193                    lexer->result_symbol = STRING_CONTENT;
194                } else {
195                    advance(lexer);
196                    array_pop(&scanner->delimiters);
197                    lexer->result_symbol = STRING_END;
198                    scanner->inside_f_string = false;
199                }
200                lexer->mark_end(lexer);
201                return true;
202
203            } else if (lexer->lookahead == '\n' && has_content && !is_triple(delimiter)) {
204                return false;
205            }
206            advance(lexer);
207            has_content = true;
208        }
209    }
210
211    lexer->mark_end(lexer);
212
213    bool found_end_of_line = false;
214    uint32_t indent_length = 0;
215    int32_t first_comment_indent_length = -1;
216    for (;;) {
217        if (lexer->lookahead == '\n') {
218            found_end_of_line = true;
219            indent_length = 0;
220            skip(lexer);
221        } else if (lexer->lookahead == ' ') {
222            indent_length++;
223            skip(lexer);
224        } else if (lexer->lookahead == '\r' || lexer->lookahead == '\f') {
225            indent_length = 0;
226            skip(lexer);
227        } else if (lexer->lookahead == '\t') {
228            indent_length += 8;
229            skip(lexer);
230        } else if (lexer->lookahead == '#' && (valid_symbols[INDENT] || valid_symbols[DEDENT] ||
231                                               valid_symbols[NEWLINE] || valid_symbols[EXCEPT])) {
232            // If we haven't found an EOL yet,
233            // then this is a comment after an expression:
234            //   foo = bar # comment
235            // Just return, since we don't want to generate an indent/dedent
236            // token.
237            if (!found_end_of_line) {
238                return false;
239            }
240            if (first_comment_indent_length == -1) {
241                first_comment_indent_length = (int32_t)indent_length;
242            }
243            while (lexer->lookahead && lexer->lookahead != '\n') {
244                skip(lexer);
245            }
246            skip(lexer);
247            indent_length = 0;
248        } else if (lexer->lookahead == '\\') {
249            skip(lexer);
250            if (lexer->lookahead == '\r') {
251                skip(lexer);
252            }
253            if (lexer->lookahead == '\n' || lexer->eof(lexer)) {
254                skip(lexer);
255            } else {
256                return false;
257            }
258        } else if (lexer->eof(lexer)) {
259            indent_length = 0;
260            found_end_of_line = true;
261            break;
262        } else {
263            break;
264        }
265    }
266
267    if (found_end_of_line) {
268        if (scanner->indents.size > 0) {
269            uint16_t current_indent_length = *array_back(&scanner->indents);
270
271            if (valid_symbols[INDENT] && indent_length > current_indent_length) {
272                array_push(&scanner->indents, indent_length);
273                lexer->result_symbol = INDENT;
274                return true;
275            }
276
277            bool next_tok_is_string_start =
278                lexer->lookahead == '\"' || lexer->lookahead == '\'' || lexer->lookahead == '`';
279
280            if ((valid_symbols[DEDENT] ||
281                 (!valid_symbols[NEWLINE] && !(valid_symbols[STRING_START] && next_tok_is_string_start) &&
282                  !within_brackets)) &&
283                indent_length < current_indent_length && !scanner->inside_f_string &&
284
285                // Wait to create a dedent token until we've consumed any
286                // comments
287                // whose indentation matches the current block.
288                first_comment_indent_length < (int32_t)current_indent_length) {
289                array_pop(&scanner->indents);
290                lexer->result_symbol = DEDENT;
291                return true;
292            }
293        }
294
295        if (valid_symbols[NEWLINE] && !error_recovery_mode) {
296            lexer->result_symbol = NEWLINE;
297            return true;
298        }
299    }
300
301    if (first_comment_indent_length == -1 && valid_symbols[STRING_START]) {
302        Delimiter delimiter = new_delimiter();
303
304        bool has_flags = false;
305        while (lexer->lookahead) {
306            if (lexer->lookahead == 'f' || lexer->lookahead == 'F') {
307                set_format(&delimiter);
308            } else if (lexer->lookahead == 'r' || lexer->lookahead == 'R') {
309                set_raw(&delimiter);
310            } else if (lexer->lookahead == 'b' || lexer->lookahead == 'B') {
311                set_bytes(&delimiter);
312            } else if (lexer->lookahead != 'u' && lexer->lookahead != 'U') {
313                break;
314            }
315            has_flags = true;
316            advance(lexer);
317        }
318
319        if (lexer->lookahead == '`') {
320            set_end_character(&delimiter, '`');
321            advance(lexer);
322            lexer->mark_end(lexer);
323        } else if (lexer->lookahead == '\'') {
324            set_end_character(&delimiter, '\'');
325            advance(lexer);
326            lexer->mark_end(lexer);
327            if (lexer->lookahead == '\'') {
328                advance(lexer);
329                if (lexer->lookahead == '\'') {
330                    advance(lexer);
331                    lexer->mark_end(lexer);
332                    set_triple(&delimiter);
333                }
334            }
335        } else if (lexer->lookahead == '"') {
336            set_end_character(&delimiter, '"');
337            advance(lexer);
338            lexer->mark_end(lexer);
339            if (lexer->lookahead == '"') {
340                advance(lexer);
341                if (lexer->lookahead == '"') {
342                    advance(lexer);
343                    lexer->mark_end(lexer);
344                    set_triple(&delimiter);
345                }
346            }
347        }
348
349        if (end_character(&delimiter)) {
350            array_push(&scanner->delimiters, delimiter);
351            lexer->result_symbol = STRING_START;
352            scanner->inside_f_string = is_format(&delimiter);
353            return true;
354        }
355        if (has_flags) {
356            return false;
357        }
358    }
359
360    return false;
361}
362
363unsigned tree_sitter_python_external_scanner_serialize(void *payload, char *buffer) {
364    Scanner *scanner = (Scanner *)payload;
365
366    size_t size = 0;
367
368    buffer[size++] = (char)scanner->inside_f_string;
369
370    size_t delimiter_count = scanner->delimiters.size;
371    if (delimiter_count > UINT8_MAX) {
372        delimiter_count = UINT8_MAX;
373    }
374    buffer[size++] = (char)delimiter_count;
375
376    if (delimiter_count > 0) {
377        memcpy(&buffer[size], scanner->delimiters.contents, delimiter_count);
378    }
379    size += delimiter_count;
380
381    uint32_t iter = 1;
382    for (; iter < scanner->indents.size && size < TREE_SITTER_SERIALIZATION_BUFFER_SIZE; ++iter) {
383        buffer[size++] = (char)*array_get(&scanner->indents, iter);
384    }
385
386    return size;
387}
388
389void tree_sitter_python_external_scanner_deserialize(void *payload, const char *buffer, unsigned length) {
390    Scanner *scanner = (Scanner *)payload;
391
392    array_delete(&scanner->delimiters);
393    array_delete(&scanner->indents);
394    array_push(&scanner->indents, 0);
395
396    if (length > 0) {
397        size_t size = 0;
398
399        scanner->inside_f_string = (bool)buffer[size++];
400
401        size_t delimiter_count = (uint8_t)buffer[size++];
402        if (delimiter_count > 0) {
403            array_reserve(&scanner->delimiters, delimiter_count);
404            scanner->delimiters.size = delimiter_count;
405            memcpy(scanner->delimiters.contents, &buffer[size], delimiter_count);
406            size += delimiter_count;
407        }
408
409        for (; size < length; size++) {
410            array_push(&scanner->indents, (unsigned char)buffer[size]);
411        }
412    }
413}
414
415void *tree_sitter_python_external_scanner_create() {
416#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)
417    _Static_assert(sizeof(Delimiter) == sizeof(char), "");
418#else
419    assert(sizeof(Delimiter) == sizeof(char));
420#endif
421    Scanner *scanner = calloc(1, sizeof(Scanner));
422    array_init(&scanner->indents);
423    array_init(&scanner->delimiters);
424    tree_sitter_python_external_scanner_deserialize(scanner, NULL, 0);
425    return scanner;
426}
427
428void tree_sitter_python_external_scanner_destroy(void *payload) {
429    Scanner *scanner = (Scanner *)payload;
430    array_delete(&scanner->indents);
431    array_delete(&scanner->delimiters);
432    free(scanner);
433}