1#include "tree_sitter/alloc.h"
  2#include "tree_sitter/parser.h"
  3
  4#include <wctype.h>
  5
  6enum TokenType {
  7    STRING_CONTENT,
  8    RAW_STRING_LITERAL_START,
  9    RAW_STRING_LITERAL_CONTENT,
 10    RAW_STRING_LITERAL_END,
 11    FLOAT_LITERAL,
 12    BLOCK_OUTER_DOC_MARKER,
 13    BLOCK_INNER_DOC_MARKER,
 14    BLOCK_COMMENT_CONTENT,
 15    LINE_DOC_CONTENT,
 16    ERROR_SENTINEL
 17};
 18
 19typedef struct {
 20    uint8_t opening_hash_count;
 21} Scanner;
 22
 23void *tree_sitter_rust_external_scanner_create() { return ts_calloc(1, sizeof(Scanner)); }
 24
 25void tree_sitter_rust_external_scanner_destroy(void *payload) { ts_free((Scanner *)payload); }
 26
 27unsigned tree_sitter_rust_external_scanner_serialize(void *payload, char *buffer) {
 28    Scanner *scanner = (Scanner *)payload;
 29    buffer[0] = (char)scanner->opening_hash_count;
 30    return 1;
 31}
 32
 33void tree_sitter_rust_external_scanner_deserialize(void *payload, const char *buffer, unsigned length) {
 34    Scanner *scanner = (Scanner *)payload;
 35    scanner->opening_hash_count = 0;
 36    if (length == 1) {
 37        Scanner *scanner = (Scanner *)payload;
 38        scanner->opening_hash_count = buffer[0];
 39    }
 40}
 41
 42static inline bool is_num_char(int32_t c) { return c == '_' || iswdigit(c); }
 43
 44static inline void advance(TSLexer *lexer) { lexer->advance(lexer, false); }
 45
 46static inline void skip(TSLexer *lexer) { lexer->advance(lexer, true); }
 47
 48static inline bool process_string(TSLexer *lexer) {
 49    bool has_content = false;
 50    for (;;) {
 51        if (lexer->lookahead == '\"' || lexer->lookahead == '\\') {
 52            break;
 53        }
 54        if (lexer->eof(lexer)) {
 55            return false;
 56        }
 57        has_content = true;
 58        advance(lexer);
 59    }
 60    lexer->result_symbol = STRING_CONTENT;
 61    lexer->mark_end(lexer);
 62    return has_content;
 63}
 64
 65static inline bool scan_raw_string_start(Scanner *scanner, TSLexer *lexer) {
 66    if (lexer->lookahead == 'b' || lexer->lookahead == 'c') {
 67        advance(lexer);
 68    }
 69    if (lexer->lookahead != 'r') {
 70        return false;
 71    }
 72    advance(lexer);
 73
 74    uint8_t opening_hash_count = 0;
 75    while (lexer->lookahead == '#') {
 76        advance(lexer);
 77        opening_hash_count++;
 78    }
 79
 80    if (lexer->lookahead != '"') {
 81        return false;
 82    }
 83    advance(lexer);
 84    scanner->opening_hash_count = opening_hash_count;
 85
 86    lexer->result_symbol = RAW_STRING_LITERAL_START;
 87    return true;
 88}
 89
 90static inline bool scan_raw_string_content(Scanner *scanner, TSLexer *lexer) {
 91    for (;;) {
 92        if (lexer->eof(lexer)) {
 93            return false;
 94        }
 95        if (lexer->lookahead == '"') {
 96            lexer->mark_end(lexer);
 97            advance(lexer);
 98            unsigned hash_count = 0;
 99            while (lexer->lookahead == '#' && hash_count < scanner->opening_hash_count) {
100                advance(lexer);
101                hash_count++;
102            }
103            if (hash_count == scanner->opening_hash_count) {
104                lexer->result_symbol = RAW_STRING_LITERAL_CONTENT;
105                return true;
106            }
107        } else {
108            advance(lexer);
109        }
110    }
111}
112
113static inline bool scan_raw_string_end(Scanner *scanner, TSLexer *lexer) {
114    advance(lexer);
115    for (unsigned i = 0; i < scanner->opening_hash_count; i++) {
116        advance(lexer);
117    }
118    lexer->result_symbol = RAW_STRING_LITERAL_END;
119    return true;
120}
121
122static inline bool process_float_literal(TSLexer *lexer) {
123    lexer->result_symbol = FLOAT_LITERAL;
124
125    advance(lexer);
126    while (is_num_char(lexer->lookahead)) {
127        advance(lexer);
128    }
129
130    bool has_fraction = false, has_exponent = false;
131
132    if (lexer->lookahead == '.') {
133        has_fraction = true;
134        advance(lexer);
135        if (iswalpha(lexer->lookahead)) {
136            // The dot is followed by a letter: 1.max(2) => not a float but an integer
137            return false;
138        }
139
140        if (lexer->lookahead == '.') {
141            return false;
142        }
143        while (is_num_char(lexer->lookahead)) {
144            advance(lexer);
145        }
146    }
147
148    lexer->mark_end(lexer);
149
150    if (lexer->lookahead == 'e' || lexer->lookahead == 'E') {
151        has_exponent = true;
152        advance(lexer);
153        if (lexer->lookahead == '+' || lexer->lookahead == '-') {
154            advance(lexer);
155        }
156        if (!is_num_char(lexer->lookahead)) {
157            return true;
158        }
159        advance(lexer);
160        while (is_num_char(lexer->lookahead)) {
161            advance(lexer);
162        }
163
164        lexer->mark_end(lexer);
165    }
166
167    if (!has_exponent && !has_fraction) {
168        return false;
169    }
170
171    if (lexer->lookahead != 'u' && lexer->lookahead != 'i' && lexer->lookahead != 'f') {
172        return true;
173    }
174    advance(lexer);
175    if (!iswdigit(lexer->lookahead)) {
176        return true;
177    }
178
179    while (iswdigit(lexer->lookahead)) {
180        advance(lexer);
181    }
182
183    lexer->mark_end(lexer);
184    return true;
185}
186
187static inline bool process_line_doc_content(TSLexer *lexer) {
188    lexer->result_symbol = LINE_DOC_CONTENT;
189    for (;;) {
190        if (lexer->eof(lexer)) {
191            return true;
192        }
193        if (lexer->lookahead == '\n') {
194            // Include the newline in the doc content node.
195            // Line endings are useful for markdown injection.
196            advance(lexer);
197            return true;
198        }
199        advance(lexer);
200    }
201}
202
// What the previously scanned character means for the block-comment scanner.
typedef enum BlockCommentState {
    LeftForwardSlash = 0, // saw '/': a following '*' opens a nested comment
    LeftAsterisk = 1,     // saw '*': a following '/' closes one nesting level
    Continuing = 2,       // no pending '/' or '*'
} BlockCommentState;
208
209typedef struct {
210    BlockCommentState state;
211    unsigned nestingDepth;
212} BlockCommentProcessing;
213
214static inline void process_left_forward_slash(BlockCommentProcessing *processing, char current) {
215    if (current == '*') {
216        processing->nestingDepth += 1;
217    }
218    processing->state = Continuing;
219};
220
221static inline void process_left_asterisk(BlockCommentProcessing *processing, char current, TSLexer *lexer) {
222    if (current == '*') {
223        lexer->mark_end(lexer);
224        processing->state = LeftAsterisk;
225        return;
226    }
227
228    if (current == '/') {
229        processing->nestingDepth -= 1;
230    }
231
232    processing->state = Continuing;
233}
234
235static inline void process_continuing(BlockCommentProcessing *processing, char current) {
236    switch (current) {
237        case '/':
238            processing->state = LeftForwardSlash;
239            break;
240        case '*':
241            processing->state = LeftAsterisk;
242            break;
243    }
244}
245
246static inline bool process_block_comment(TSLexer *lexer, const bool *valid_symbols) {
247    char first = (char)lexer->lookahead;
248    // The first character is stored so we can safely advance inside
249    // these if blocks. However, because we only store one, we can only
250    // safely advance 1 time. Since there's a chance that an advance could
251    // happen in one state, we must advance in all states to ensure that
252    // the program ends up in a sane state prior to processing the block
253    // comment if need be.
254    if (valid_symbols[BLOCK_INNER_DOC_MARKER] && first == '!') {
255        lexer->result_symbol = BLOCK_INNER_DOC_MARKER;
256        advance(lexer);
257        return true;
258    }
259    if (valid_symbols[BLOCK_OUTER_DOC_MARKER] && first == '*') {
260        advance(lexer);
261        lexer->mark_end(lexer);
262        // If the next token is a / that means that it's an empty block comment.
263        if (lexer->lookahead == '/') {
264            return false;
265        }
266        // If the next token is a * that means that this isn't a BLOCK_OUTER_DOC_MARKER
267        // as BLOCK_OUTER_DOC_MARKER's only have 2 * not 3 or more.
268        if (lexer->lookahead != '*') {
269            lexer->result_symbol = BLOCK_OUTER_DOC_MARKER;
270            return true;
271        }
272    } else {
273        advance(lexer);
274    }
275
276    if (valid_symbols[BLOCK_COMMENT_CONTENT]) {
277        BlockCommentProcessing processing = {Continuing, 1};
278        // Manually set the current state based on the first character
279        switch (first) {
280            case '*':
281                processing.state = LeftAsterisk;
282                if (lexer->lookahead == '/') {
283                    // This case can happen in an empty doc block comment
284                    // like /*!*/. The comment has no contents, so bail.
285                    return false;
286                }
287                break;
288            case '/':
289                processing.state = LeftForwardSlash;
290                break;
291            default:
292                processing.state = Continuing;
293                break;
294        }
295
296        // For the purposes of actually parsing rust code, this
297        // is incorrect as it considers an unterminated block comment
298        // to be an error. However, for the purposes of syntax highlighting
299        // this should be considered successful as otherwise you are not able
300        // to syntax highlight a block of code prior to closing the
301        // block comment
302        while (!lexer->eof(lexer) && processing.nestingDepth != 0) {
303            // Set first to the current lookahead as that is the second character
304            // as we force an advance in the above code when we are checking if we
305            // need to handle a block comment inner or outer doc comment signifier
306            // node
307            first = (char)lexer->lookahead;
308            switch (processing.state) {
309                case LeftForwardSlash:
310                    process_left_forward_slash(&processing, first);
311                    break;
312                case LeftAsterisk:
313                    process_left_asterisk(&processing, first, lexer);
314                    break;
315                case Continuing:
316                    lexer->mark_end(lexer);
317                    process_continuing(&processing, first);
318                    break;
319                default:
320                    break;
321            }
322            advance(lexer);
323            if (first == '/' && processing.nestingDepth != 0) {
324                lexer->mark_end(lexer);
325            }
326        }
327        lexer->result_symbol = BLOCK_COMMENT_CONTENT;
328        return true;
329    }
330
331    return false;
332}
333
334bool tree_sitter_rust_external_scanner_scan(void *payload, TSLexer *lexer, const bool *valid_symbols) {
335    // The documentation states that if the lexical analysis fails for some reason
336    // they will mark every state as valid and pass it to the external scanner
337    // However, we can't do anything to help them recover in that case so we
338    // should just fail.
339    /*
340      link: https://tree-sitter.github.io/tree-sitter/creating-parsers#external-scanners
341      If a syntax error is encountered during regular parsing, Tree-sitter’s
342      first action during error recovery will be to call the external scanner’s
343      scan function with all tokens marked valid. The scanner should detect this
344      case and handle it appropriately. One simple method of detection is to add
345      an unused token to the end of the externals array, for example
346
347      externals: $ => [$.token1, $.token2, $.error_sentinel],
348
349      then check whether that token is marked valid to determine whether
350      Tree-sitter is in error correction mode.
351    */
352    if (valid_symbols[ERROR_SENTINEL]) {
353        return false;
354    }
355
356    Scanner *scanner = (Scanner *)payload;
357
358    if (valid_symbols[BLOCK_COMMENT_CONTENT] || valid_symbols[BLOCK_INNER_DOC_MARKER] ||
359        valid_symbols[BLOCK_OUTER_DOC_MARKER]) {
360        return process_block_comment(lexer, valid_symbols);
361    }
362
363    if (valid_symbols[STRING_CONTENT] && !valid_symbols[FLOAT_LITERAL]) {
364        return process_string(lexer);
365    }
366
367    if (valid_symbols[LINE_DOC_CONTENT]) {
368        return process_line_doc_content(lexer);
369    }
370
371    while (iswspace(lexer->lookahead)) {
372        skip(lexer);
373    }
374
375    if (valid_symbols[RAW_STRING_LITERAL_START] &&
376        (lexer->lookahead == 'r' || lexer->lookahead == 'b' || lexer->lookahead == 'c')) {
377        return scan_raw_string_start(scanner, lexer);
378    }
379
380    if (valid_symbols[RAW_STRING_LITERAL_CONTENT]) {
381        return scan_raw_string_content(scanner, lexer);
382    }
383
384    if (valid_symbols[RAW_STRING_LITERAL_END] && lexer->lookahead == '"') {
385        return scan_raw_string_end(scanner, lexer);
386    }
387
388    if (valid_symbols[FLOAT_LITERAL] && iswdigit(lexer->lookahead)) {
389        return process_float_literal(lexer);
390    }
391
392    return false;
393}