#include "tree_sitter/array.h"
#include "tree_sitter/parser.h"

#include <stdbool.h>
#include <stdint.h>
#include <string.h>
#include <wctype.h>
  5
  6enum TokenType {
  7    SEMI,
  8    CLASS_MEMBER_SEMI,
  9    BLOCK_COMMENT,
 10    NOT_IS,
 11    IN,
 12    Q_DOT,
 13    MULTILINE_STRING_CONTENT,
 14    CONSTRUCTOR,
 15    GET,
 16    SET,
 17    DOLLAR,
 18};
 19
 20#define MAX_WORD_SIZE 16
 21#define MAX_WORDS 16
 22
 23static inline void advance(TSLexer *lexer) { lexer->advance(lexer, false); }
 24
 25static inline void skip(TSLexer *lexer) { lexer->advance(lexer, true); }
 26
 27static inline bool scan_whitespace_and_comments(TSLexer *lexer) {
 28    while (iswspace(lexer->lookahead)) {
 29        skip(lexer);
 30    }
 31    return lexer->lookahead != '/';
 32}
 33
 34static bool scan_word(TSLexer *lexer, const char *const word) {
 35    for (uint8_t i = 0; word[i] != '\0'; i++) {
 36        if (lexer->lookahead != word[i]) {
 37            return false;
 38        }
 39        skip(lexer);
 40    }
 41    return true;
 42}
 43
 44static bool scan_words(TSLexer *lexer, const char words[MAX_WORDS][MAX_WORD_SIZE], char scanned_word[16],
 45                       uint8_t *index) {
 46    if (!scanned_word[0]) {
 47        for (uint8_t i = 0; i < MAX_WORD_SIZE - 1; i++) {
 48            if (!iswalpha(lexer->lookahead)) {
 49                if (i == 0) {
 50                    return false;
 51                }
 52                break;
 53            }
 54            scanned_word[i] = (char)lexer->lookahead;
 55            skip(lexer);
 56        }
 57    }
 58
 59    for (uint8_t i = 0; i < MAX_WORDS; i++) {
 60        if (strncmp(scanned_word, words[i], MAX_WORD_SIZE) == 0) {
 61            if (index != NULL) {
 62                *index = i;
 63            }
 64            return true;
 65        }
 66    }
 67
 68    return false;
 69}
 70
 71void *tree_sitter_kotlin_external_scanner_create() { return NULL; }
 72
 73void tree_sitter_kotlin_external_scanner_destroy(void *payload) {}
 74
 75unsigned tree_sitter_kotlin_external_scanner_serialize(void *payload, char *buffer) { return 0; }
 76
 77void tree_sitter_kotlin_external_scanner_deserialize(void *payload, const char *buffer, unsigned length) {}
 78
 79bool tree_sitter_kotlin_external_scanner_scan(void *payload, TSLexer *lexer, const bool *valid_symbols) {
 80    if (valid_symbols[MULTILINE_STRING_CONTENT]) {
 81        bool did_advance = false;
 82        lexer->result_symbol = MULTILINE_STRING_CONTENT;
 83        while (!lexer->eof(lexer)) {
 84            switch (lexer->lookahead) {
 85                case '$':
 86                    lexer->mark_end(lexer);
 87                    advance(lexer);
 88                    if (iswalpha(lexer->lookahead) || lexer->lookahead == '{') {
 89                        return did_advance;
 90                    }
 91                    did_advance = true;
 92                    break;
 93                case '"':
 94                    lexer->mark_end(lexer);
 95                    // 3 or 4 quotes means we're done
 96                    advance(lexer);
 97                    if (lexer->lookahead == '"') {
 98                        advance(lexer);
 99                        if (lexer->lookahead == '"') {
100                            advance(lexer);
101                            if (lexer->lookahead == '"') {
102                                advance(lexer);
103                            }
104                            return did_advance;
105                        }
106                    }
107                    did_advance = true;
108                    break;
109                default:
110                    advance(lexer);
111                    did_advance = true;
112                    break;
113            }
114        }
115    }
116
117    if (valid_symbols[SEMI] || valid_symbols[CLASS_MEMBER_SEMI]) {
118        lexer->result_symbol = valid_symbols[SEMI] ? SEMI : CLASS_MEMBER_SEMI;
119        lexer->mark_end(lexer);
120        bool saw_newline = false;
121        for (;;) {
122            if (lexer->eof(lexer)) {
123                return true;
124            }
125
126            if (lexer->lookahead == ';') {
127                advance(lexer);
128                lexer->mark_end(lexer);
129                return true;
130            }
131
132            if (!iswspace(lexer->lookahead)) {
133                break;
134            }
135
136            if (lexer->lookahead == '\n') {
137                skip(lexer);
138                saw_newline = true;
139                break;
140            }
141
142            if (lexer->lookahead == '\r') {
143                skip(lexer);
144
145                if (lexer->lookahead == '\n') {
146                    skip(lexer);
147                }
148
149                saw_newline = true;
150                break;
151            }
152
153            skip(lexer);
154        }
155
156        // Skip whitespace and comments
157        while (iswspace(lexer->lookahead)) {
158            skip(lexer);
159        }
160        if (lexer->lookahead == '/') {
161            goto comment;
162        }
163
164        if (!saw_newline) {
165            switch (lexer->lookahead) {
166                case '!':
167                    skip(lexer);
168                    goto continue_not_is_from_semi;
169                case '?':
170                    if (valid_symbols[Q_DOT]) {
171                        goto q_dot_from_semi;
172                    }
173                    return false;
174                case 'i':
175                    return scan_word(lexer, "import");
176                case ';':
177                    advance(lexer);
178                    lexer->mark_end(lexer);
179                    return true;
180                default:
181                    return false;
182            }
183        }
184
185        char scanned_word[16] = {0};
186    _switch:
187        switch (lexer->lookahead) {
188            case ',':
189            case '.':
190            case ':':
191            case '*':
192            case '%':
193            case '>':
194            case '<':
195            case '=':
196            case '{':
197            case '[':
198            case '|':
199            case '&':
200            case '/':
201                return false;
202            // Insert a semicolon before `--` and `++`, but not before binary `+` or `-`.
203            // Insert before +/-{float}
204            case '+':
205                skip(lexer);
206                if (lexer->lookahead == '+') {
207                    return true;
208                }
209                return iswdigit(lexer->lookahead);
210            case '-':
211                skip(lexer);
212                if (lexer->lookahead == '-') {
213                    return true;
214                }
215                return iswdigit(lexer->lookahead);
216            // Don't insert a semicolon before `!=`, but do insert one before a unary `!`.
217            case '!':
218                skip(lexer);
219                if (lexer->lookahead == 'i' && valid_symbols[NOT_IS]) {
220                    skip(lexer);
221                    if (lexer->lookahead == 's') {
222                        skip(lexer);
223                        if (!iswalnum(lexer->lookahead)) {
224                            return true;
225                        }
226                    }
227                }
228                return lexer->lookahead != '=';
229            case '?':
230                if (valid_symbols[Q_DOT]) {
231                    goto q_dot_from_semi;
232                }
233                return true;
234            case 'e':
235            case 'i':
236            case 'g':
237            case 's':
238            case 'p':
239            case 'a':
240            case 'f':
241            case 'o':
242            case 'l':
243            case 'v':
244            case 'n':
245            case 'c':
246            case 'b':
247            case 'w':
248                while (scan_words(lexer,
249                                  (const char[16][16]){"public", "private", "protected", "internal", "abstract",
250                                                       "final", "open", "override", "lateinit", "vararg", "noinline",
251                                                       "crossinline", "external", "suspend", "inline"},
252                                  scanned_word, NULL)) {
253                    memset(scanned_word, 0, MAX_WORD_SIZE);
254                    while (iswspace(lexer->lookahead)) {
255                        skip(lexer);
256                    }
257                }
258
259                uint8_t index = -1;
260                bool res = scan_words(
261                    lexer,
262                    (const char[16][16]){"else", "in", "instanceof", "get", "set", "constructor", "by", "as", "where"},
263                    scanned_word, &index);
264
265                // If `CLASS_MEMBER_SEMI` is valid, we found a secondary constructor and so we want to insert a semi, OR
266                // we found a variable named constructor whose field is being accessed
267                if (index == 5) {
268                    while (iswspace(lexer->lookahead)) {
269                        skip(lexer);
270                    }
271                    if (valid_symbols[CLASS_MEMBER_SEMI] || lexer->lookahead == '.' || lexer->lookahead == '=') {
272                        return true;
273                    }
274                }
275                // Ordinarily, we should not insert a semicolon if there is an `else` on the next line,
276                // except for when it's a 'when entry', which has a `->` after the `else`.
277                else if (index == 0) {
278                    while (iswspace(lexer->lookahead)) {
279                        skip(lexer);
280                    }
281                    if (lexer->lookahead == '-') {
282                        skip(lexer);
283                        if (lexer->lookahead == '>') {
284                            return true;
285                        }
286                    }
287                }
288                // If `get` was found and the keyword is not valid, return a semi since it's being used as an identifier
289                else if (index == 3 && (!valid_symbols[GET] || lexer->lookahead == '[')) {
290                    return true;
291                }
292                // If `set` was found and the keyword is not valid, return a semi since it's being used as an identifier
293                else if (index == 4 && (!valid_symbols[SET] || lexer->lookahead == '[' || lexer->lookahead == '(' ||
294                                        lexer->lookahead == '.')) {
295                    if (lexer->lookahead == '(' && valid_symbols[SET]) {
296                        // skip until the closing parenthesis
297                        while (lexer->lookahead != ')' && !lexer->eof(lexer)) {
298                            skip(lexer);
299                        }
300                        skip(lexer);
301
302                        while (iswspace(lexer->lookahead)) {
303                            if (lexer->lookahead == '\n') {
304                                return true;
305                            }
306                            skip(lexer);
307                        }
308                        return false;
309                    }
310                    return true;
311                }
312                // If `in` was found and this specific external keyword is valid,
313                // return a semi since it's being used in a range test
314                else if (index == 1 && valid_symbols[IN]) {
315                    return true;
316                }
317                return !res;
318            case ';':
319                advance(lexer);
320                lexer->mark_end(lexer);
321                return true;
322            case '@':
323                if (valid_symbols[CONSTRUCTOR]) {
324                    while (!iswspace(lexer->lookahead)) {
325                        skip(lexer);
326                    }
327                    while (iswspace(lexer->lookahead)) {
328                        skip(lexer);
329                    }
330                    char ctor[12] = "constructor";
331                    for (uint8_t i = 0; i < 11; i++) {
332                        if (lexer->lookahead != ctor[i]) {
333                            return true;
334                        }
335                        skip(lexer);
336                    }
337                    return false;
338                }
339                if (valid_symbols[GET] || valid_symbols[SET]) {
340                    bool saw_paren = false;
341                    while ((saw_paren ? lexer->lookahead != '\n' : !iswspace(lexer->lookahead))) {
342                        skip(lexer);
343                        if (lexer->lookahead == '(') {
344                            saw_paren = true;
345                        }
346                        if (lexer->lookahead == ')') {
347                            saw_paren = false;
348                        }
349                    }
350                    while (iswspace(lexer->lookahead)) {
351                        skip(lexer);
352                    }
353                    if (lexer->lookahead == '/') {
354                        return true;
355                    }
356                    goto _switch;
357                }
358                return true;
359
360            default:
361                return true;
362        }
363    }
364
365    while (iswspace(lexer->lookahead)) {
366        skip(lexer);
367    }
368
369    if (valid_symbols[NOT_IS]) {
370        if (lexer->lookahead == '!') {
371            advance(lexer);
372        continue_not_is_from_semi:
373            if (lexer->lookahead == 'i') {
374                advance(lexer);
375                if (lexer->lookahead == 's') {
376                    advance(lexer);
377                    lexer->result_symbol = NOT_IS;
378                    lexer->mark_end(lexer);
379                    return !iswalnum(lexer->lookahead);
380                }
381            }
382        }
383    }
384
385    if (valid_symbols[IN]) {
386        if (lexer->lookahead == 'i') {
387            advance(lexer);
388            if (lexer->lookahead == 'n') {
389                advance(lexer);
390                lexer->result_symbol = IN;
391                lexer->mark_end(lexer);
392                return !iswalnum(lexer->lookahead);
393            }
394        }
395    }
396
397q_dot_from_semi:
398    if (valid_symbols[Q_DOT]) {
399        while (iswspace(lexer->lookahead)) {
400            skip(lexer);
401        }
402        if (lexer->lookahead == '?') {
403            advance(lexer);
404            while (iswspace(lexer->lookahead)) {
405                skip(lexer);
406            }
407            if (lexer->lookahead == '.') {
408                advance(lexer);
409                lexer->result_symbol = Q_DOT;
410                lexer->mark_end(lexer);
411                return true;
412            }
413        }
414    }
415
416comment:
417    if (valid_symbols[DOLLAR]) {
418        return false;
419    }
420
421    if (lexer->lookahead == '/') {
422        advance(lexer);
423        if (lexer->lookahead != '*') {
424            return false;
425        }
426        advance(lexer);
427
428        bool after_star = false;
429        unsigned nesting_depth = 1;
430        for (;;) {
431            switch (lexer->lookahead) {
432                case '\0':
433                    return false;
434                case '*':
435                    advance(lexer);
436                    after_star = true;
437                    break;
438                case '/':
439                    if (after_star) {
440                        advance(lexer);
441                        after_star = false;
442                        nesting_depth--;
443                        if (nesting_depth == 0) {
444                            lexer->result_symbol = BLOCK_COMMENT;
445                            lexer->mark_end(lexer);
446                            return true;
447                        }
448                    } else {
449                        advance(lexer);
450                        after_star = false;
451                        if (lexer->lookahead == '*') {
452                            nesting_depth++;
453                            advance(lexer);
454                        }
455                    }
456                    break;
457                default:
458                    advance(lexer);
459                    after_star = false;
460                    break;
461            }
462        }
463    }
464
465    return false;
466}