// FIXME:
// - Truncate long argument lists.
3
#define _GNU_SOURCE
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <getopt.h>
#include <limits.h>
#include <pthread.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <tree_sitter/api.h>

#include "file.h"
#include "list.h"
#include "tpool.h"

#include "queries/c.h"
#include "queries/cpp.h"
#include "queries/cuda.h"
#include "queries/glsl.h"
#include "queries/go.h"
#include "queries/javascript.h"
#include "queries/kotlin.h"
#include "queries/lua.h"
#include "queries/odin.h"
#include "queries/php.h"
#include "queries/python.h"
#include "queries/rust.h"
#include "queries/tcl.h"
#include "queries/zig.h"
33
/* Non-zero enables diagnostic messages. main() sets it to 1 when the DEBUG
 * environment variable is "1" or "true"; worker jobs only read it.
 * NOTE(review): left non-static, presumably so other translation units can
 * declare it extern -- confirm before changing its linkage. */
int debug_enabled = 0;

/* Grammar entry points, one per supported language. They are defined outside
 * this file -- presumably by the tree-sitter grammar libraries linked into the
 * binary. */
TSLanguage *tree_sitter_c(void);
TSLanguage *tree_sitter_cpp(void);
TSLanguage *tree_sitter_go(void);
TSLanguage *tree_sitter_python(void);
TSLanguage *tree_sitter_php(void);
TSLanguage *tree_sitter_rust(void);
TSLanguage *tree_sitter_javascript(void);
TSLanguage *tree_sitter_lua(void);
TSLanguage *tree_sitter_zig(void);
TSLanguage *tree_sitter_kotlin(void);
TSLanguage *tree_sitter_odin(void);
TSLanguage *tree_sitter_tcl(void);
TSLanguage *tree_sitter_glsl(void);
TSLanguage *tree_sitter_cuda(void);
50
/* Smaller of two values. Each argument may be evaluated twice, so never pass
 * expressions with side effects (e.g. MIN(i++, j)). */
#define MIN(a, b) ((a) < (b) ? (a) : (b))
52
/*
 * Returns the Levenshtein edit distance (unit-cost insert, delete and
 * substitute) between s1 and s2, comparing bytes exactly.
 *
 * Keeps a single DP row sized by the shorter string, not the full
 * (len1 + 1) x (len2 + 1) matrix. The previous stack VLA grew with the product
 * of both lengths and could overflow the stack of the pool worker thread this
 * runs on, e.g. for a long search term or identifier. Short rows live on the
 * stack, and longer ones are heap-allocated. Exits the process if that
 * allocation fails, matching the file's other allocation failures.
 */
int levenshtein_distance(const char *s1, const char *s2) {
    size_t len1 = strlen(s1);
    size_t len2 = strlen(s2);

    /* The distance is symmetric: make s2 the shorter string so the row is minimal. */
    if (len2 > len1) {
        const char *tmp_str = s1;
        s1 = s2;
        s2 = tmp_str;
        size_t tmp_len = len1;
        len1 = len2;
        len2 = tmp_len;
    }

    enum { STACK_ROW_LEN = 64 };
    size_t stack_row[STACK_ROW_LEN];
    size_t *row = stack_row;
    if (len2 + 1 > STACK_ROW_LEN) {
        row = malloc((len2 + 1) * sizeof *row);
        if (row == NULL) {
            perror("malloc");
            exit(EXIT_FAILURE);
        }
    }

    /* Row 0: distance from the empty prefix of s1 to each prefix of s2. */
    for (size_t j = 0; j <= len2; j++) {
        row[j] = j;
    }

    for (size_t i = 1; i <= len1; i++) {
        size_t diag = row[0]; /* D[i-1][j-1] for j == 1 */
        row[0] = i;
        for (size_t j = 1; j <= len2; j++) {
            size_t above = row[j]; /* D[i-1][j], about to be overwritten */
            size_t best = diag + (s1[i - 1] == s2[j - 1] ? 0 : 1);
            if (above + 1 < best) {
                best = above + 1; /* deletion */
            }
            if (row[j - 1] + 1 < best) {
                best = row[j - 1] + 1; /* insertion; row[j-1] already holds D[i][j-1] */
            }
            row[j] = best;
            diag = above;
        }
    }

    size_t distance = row[len2];
    if (row != stack_row) {
        free(row);
    }
    return (int)distance;
}
74
/* Text pulled from one tree-sitter query match.
 * Each string field is a heap copy made by extract_value() and must be freed
 * by the owner. A field stays NULL when the match has no capture of that name. */
typedef struct {
    const char *fname;   /* text of the @fname capture; compared to the search term */
    const char *ftype;   /* text of the @ftype capture, if present */
    const char *fparams; /* text of the @fparams capture, if present */
    size_t lineno;       /* 1-based line where the @fname capture starts */
} Function;
81
82const char *extract_value(TSNode captured_node, const char *source_code) {
83 size_t start = ts_node_start_byte(captured_node);
84 size_t end = ts_node_end_byte(captured_node);
85 size_t length = end - start;
86 char *buffer = malloc(length + 1); // +1 for the null terminator
87
88 if (buffer != NULL) {
89 snprintf(buffer, length + 1, "%.*s", (int)length, &source_code[start]);
90 return buffer;
91 } else {
92 perror("malloc");
93 exit(EXIT_FAILURE);
94 }
95
96 return NULL;
97}
98
/*
 * Returns a heap-allocated copy of `str` with every line-break character
 * ('\n' and '\r') removed, so a multi-line parameter list prints on one line.
 * This also holds for CRLF source files. Returns NULL when `str` is NULL.
 * The caller frees the result. Exits the process if allocation fails.
 */
char *remove_newlines(const char *str) {
    if (str == NULL) {
        return NULL;
    }

    size_t length = strlen(str);
    char *result = malloc(length + 1); /* +1 for the terminator */
    if (result == NULL) {
        fprintf(stderr, "Memory allocation failed\n");
        exit(1);
    }

    size_t j = 0;
    for (size_t i = 0; i < length; i++) {
        /* A leftover '\r' would move the terminal cursor back to column 0
         * and overwrite the start of the printed match. */
        if (str[i] != '\n' && str[i] != '\r') {
            result[j++] = str[i];
        }
    }

    result[j] = '\0';
    return result;
}
119
/* One unit of work queued on the thread pool for parse_source_file().
 * The job frees both this struct and `source_code` when it finishes. */
struct ThreadArgs {
    const char *file_path;    /* printed with each match; borrowed from main()'s file list, which outlives all jobs */
    const char *source_code;  /* NUL-terminated file contents; ownership passes to the job */
    TSLanguage *language;     /* grammar used to parse source_code */
    const char *query_string; /* tree-sitter query source; its length is query_len */
    uint32_t query_len;       /* byte length of query_string */
    const char *cfname;       /* search term entered on the command line */
    int case_sensitive;       /* non-zero: strstr substring match; zero: strcasestr */
    int max_distance;         /* > 0 selects Levenshtein matching with this threshold */
};
130
131// void parse_source_file(const char *file_path, const char *source_code,
132// TSLanguage *language, const char *cfname) {
133void parse_source_file(void *arg) {
134 struct ThreadArgs *args = (struct ThreadArgs *)arg;
135
136 const char *file_path = args->file_path;
137 const char *source_code = args->source_code;
138 TSLanguage *language = args->language;
139 const char *cfname = args->cfname;
140 int case_sensitive = args->case_sensitive;
141 int max_distance = args->max_distance;
142
143 TSParser *parser = ts_parser_new();
144 ts_parser_set_language(parser, language);
145
146 TSTree *tree = ts_parser_parse_string(parser, NULL, source_code, strlen(source_code));
147 if (tree == NULL) {
148 if (debug_enabled) {
149 fprintf(stderr, "Parsing failed for file: %s\n", file_path);
150 }
151 ts_parser_delete(parser);
152 free((void *)source_code);
153 free(args);
154 return;
155 }
156 TSNode root_node = ts_tree_root_node(tree);
157
158 const char *query_string = args->query_string;
159 uint32_t query_len = args->query_len;
160
161 uint32_t error_offset;
162 TSQueryError error_type;
163 TSQuery *query = ts_query_new(language, query_string, query_len, &error_offset, &error_type);
164
165 if (query == NULL) {
166 if (debug_enabled) {
167 printf("Query creation failed at offset %u with error type %d\n", error_offset, error_type);
168 }
169 ts_tree_delete(tree);
170 ts_parser_delete(parser);
171 free((void *)source_code);
172 free(args);
173 return;
174 }
175
176 TSQueryCursor *query_cursor = ts_query_cursor_new();
177 ts_query_cursor_exec(query_cursor, query, root_node);
178
179 TSQueryMatch match;
180 while (ts_query_cursor_next_match(query_cursor, &match)) {
181 Function fn = {0};
182
183 for (unsigned i = 0; i < match.capture_count; i++) {
184 TSQueryCapture capture = match.captures[i];
185 TSNode captured_node = capture.node;
186
187 uint32_t capture_name_length;
188 const char *capture_name = ts_query_capture_name_for_id(
189 query, capture.index, &capture_name_length);
190
191 if (strcmp(capture_name, "fname") == 0) {
192 fn.fname = extract_value(captured_node, source_code);
193
194 TSPoint start_point = ts_node_start_point(captured_node);
195 fn.lineno = start_point.row + 1;
196 }
197
198 if (strcmp(capture_name, "ftype") == 0) {
199 fn.ftype = extract_value(captured_node, source_code);
200 }
201
202 if (strcmp(capture_name, "fparams") == 0) {
203 fn.fparams = extract_value(captured_node, source_code);
204 }
205 }
206
207 // Substring matching.
208 if (fn.fname != NULL) {
209 char *result = NULL;
210 int distance = -1;
211
212 if (max_distance > 0) {
213 distance = levenshtein_distance(fn.fname, cfname);
214 if (distance <= max_distance) {
215 // We treat it as a match, but result pointer logic is different
216 // For printing purposes effectively a match.
217 // We'll just set result to non-null to trigger the print.
218 result = (char *)fn.fname;
219 }
220 } else {
221 if (case_sensitive) {
222 result = strstr(fn.fname, cfname);
223 } else {
224 result = strcasestr(fn.fname, cfname);
225 }
226 }
227
228 if (result != NULL) {
229 char *fparams_formatted = remove_newlines(fn.fparams);
230 if (max_distance > 0) {
231 printf("%s:%zu: %s %s %s (dist: %d)\n", file_path, fn.lineno, fn.ftype ? fn.ftype : "", fn.fname, fparams_formatted ? fparams_formatted : "", distance);
232 } else {
233 printf("%s:%zu: %s %s %s\n", file_path, fn.lineno, fn.ftype ? fn.ftype : "", fn.fname, fparams_formatted ? fparams_formatted : "");
234 }
235 free(fparams_formatted);
236 }
237 }
238
239 // Free captured values
240 free((void *)fn.fname);
241 free((void *)fn.ftype);
242 free((void *)fn.fparams);
243 }
244
245 ts_query_cursor_delete(query_cursor);
246 ts_query_delete(query);
247 ts_tree_delete(tree);
248 ts_parser_delete(parser);
249
250 // Cleanup thread arguments
251 free((void *)source_code);
252 free(args);
253}
254
/*
 * Returns a pointer into `file_path` just past the last '.' of its final
 * path component, or NULL when that component contains no '.'.
 * Only the last component is searched, so dots in directory names are
 * ignored: "./dir.d/Makefile" has no extension, not "d/Makefile".
 */
const char *get_file_extension(const char *file_path) {
    const char *last_slash = strrchr(file_path, '/');
    const char *base = (last_slash != NULL) ? last_slash + 1 : file_path;
    const char *dot = strrchr(base, '.');
    return (dot != NULL) ? dot + 1 : NULL;
}
262
263int main(int argc, char *argv[]) {
264 int case_sensitive = 0;
265 int max_distance = 0;
266 int max_depth = -1;
267 int opt;
268 struct option long_options[] = {
269 {"case-sensitive", no_argument, 0, 'c'},
270 {"levenshtein", required_argument, 0, 'l'},
271 {"depth", required_argument, 0, 'd'},
272 {0, 0, 0, 0}};
273
274 while ((opt = getopt_long(argc, argv, "cl:d:", long_options, NULL)) != -1) {
275 switch (opt) {
276 case 'c':
277 case_sensitive = 1;
278 break;
279 case 'l':
280 max_distance = atoi(optarg);
281 break;
282 case 'd':
283 max_depth = atoi(optarg);
284 break;
285 default:
286 fprintf(stderr, "Usage: %s [-c|--case-sensitive] [-l|--levenshtein <dist>] [-d|--depth <level>] <search term> [directory|file]\n", argv[0]);
287 return 1;
288 }
289 }
290
291 if (optind >= argc) {
292 fprintf(stderr, "Usage: %s [-c|--case-sensitive] [-l|--levenshtein <dist>] [-d|--depth <level>] <search term> [directory|file]\n", argv[0]);
293 return 1;
294 }
295
296 const char *cfname = argv[optind];
297 char *directory = (optind + 1 < argc) ? argv[optind + 1] : ".";
298
299 Node *head = NULL;
300 list_files_recursively(directory, &head, max_depth, 0);
301 int list_size = size_of_file_list(head);
302
303 const char *debug_env = getenv("DEBUG");
304 if (debug_env != NULL && (strcmp(debug_env, "1") == 0 || strcmp(debug_env, "true") == 0)) {
305 debug_enabled = 1;
306 }
307
308 if (debug_enabled) {
309 printf("Scanning %d files\n", list_size);
310 }
311
312 ThreadPool *pool = tp_create(8);
313 if (!pool) {
314 perror("Failed to create thread pool");
315 return 1;
316 }
317
318 Node *current = head;
319 while (current != NULL) {
320 const char *file_path = current->file_path;
321 const char *extension = get_file_extension(file_path);
322
323 TSLanguage *lang = NULL;
324 const char *query_string = NULL;
325 uint32_t query_len = 0;
326
327 if (extension != NULL) {
328 if (strcmp(extension, "c") == 0 || strcmp(extension, "h") == 0) {
329 lang = tree_sitter_c();
330 query_string = (const char *)query_c;
331 query_len = query_c_len;
332 } else if (strcmp(extension, "cpp") == 0 || strcmp(extension, "hpp") == 0) {
333 lang = tree_sitter_cpp();
334 query_string = (const char *)query_cpp;
335 query_len = query_cpp_len;
336 } else if (strcmp(extension, "go") == 0) {
337 lang = tree_sitter_go();
338 query_string = (const char *)query_go;
339 query_len = query_go_len;
340 } else if (strcmp(extension, "py") == 0) {
341 lang = tree_sitter_python();
342 query_string = (const char *)query_python;
343 query_len = query_python_len;
344 } else if (strcmp(extension, "php") == 0) {
345 lang = tree_sitter_php();
346 query_string = (const char *)query_php;
347 query_len = query_php_len;
348 } else if (strcmp(extension, "rs") == 0) {
349 lang = tree_sitter_rust();
350 query_string = (const char *)query_rust;
351 query_len = query_rust_len;
352 } else if (strcmp(extension, "js") == 0) {
353 lang = tree_sitter_javascript();
354 query_string = (const char *)query_javascript;
355 query_len = query_javascript_len;
356 } else if (strcmp(extension, "lua") == 0) {
357 lang = tree_sitter_lua();
358 query_string = (const char *)query_lua;
359 query_len = query_lua_len;
360 } else if (strcmp(extension, "zig") == 0) {
361 lang = tree_sitter_zig();
362 query_string = (const char *)query_zig;
363 query_len = query_zig_len;
364 } else if (strcmp(extension, "kt") == 0) {
365 lang = tree_sitter_kotlin();
366 query_string = (const char *)query_kotlin;
367 query_len = query_kotlin_len;
368 } else if (strcmp(extension, "odin") == 0) {
369 lang = tree_sitter_odin();
370 query_string = (const char *)query_odin;
371 query_len = query_odin_len;
372 } else if (strcmp(extension, "tcl") == 0) {
373 lang = tree_sitter_tcl();
374 query_string = (const char *)query_tcl;
375 query_len = query_tcl_len;
376 } else if (strcmp(extension, "glsl") == 0) {
377 lang = tree_sitter_glsl();
378 query_string = (const char *)query_glsl;
379 query_len = query_glsl_len;
380 } else if (strcmp(extension, "cu") == 0 || strcmp(extension, "cuh") == 0) {
381 lang = tree_sitter_cuda();
382 query_string = (const char *)query_cuda;
383 query_len = query_cuda_len;
384 }
385 }
386
387 if (lang != NULL && query_string != NULL) {
388 struct FileContent source_file = read_entire_file(file_path);
389 if (source_file.content != NULL) {
390 struct ThreadArgs *thread_args = malloc(sizeof(struct ThreadArgs));
391 if (!thread_args) {
392 perror("Failed to allocate thread args");
393 free((void *)source_file.content);
394 continue;
395 }
396
397 thread_args->file_path = file_path;
398 thread_args->source_code = source_file.content;
399 thread_args->language = lang;
400 thread_args->query_string = query_string;
401 thread_args->query_len = query_len;
402 thread_args->cfname = cfname;
403 thread_args->case_sensitive = case_sensitive;
404 thread_args->max_distance = max_distance;
405
406 tp_add_job(pool, (thread_func_t)parse_source_file, thread_args);
407 } else {
408 if (debug_enabled) {
409 fprintf(stderr, "Failed to read file: %s\n", file_path);
410 }
411 }
412 }
413
414 current = current->next;
415 }
416
417 tp_wait(pool);
418 tp_destroy(pool);
419 free_file_list(head);
420 return 0;
421}