1// FIXME:
  2//  - Truncate longer argument list.
  3
  4#define _GNU_SOURCE
  5#include <assert.h>
  6#include <getopt.h>
  7#include <pthread.h>
  8#include <stdarg.h>
  9#include <stdio.h>
 10#include <stdlib.h>
 11#include <string.h>
 12
 13#include <tree_sitter/api.h>
 14
 15#include "file.h"
 16#include "list.h"
 17#include "tpool.h"
 18
 19#include "queries/c.h"
 20#include "queries/cpp.h"
 21#include "queries/cuda.h"
 22#include "queries/glsl.h"
 23#include "queries/go.h"
 24#include "queries/javascript.h"
 25#include "queries/kotlin.h"
 26#include "queries/lua.h"
 27#include "queries/odin.h"
 28#include "queries/php.h"
 29#include "queries/python.h"
 30#include "queries/rust.h"
 31#include "queries/tcl.h"
 32#include "queries/zig.h"
 33
/* Non-zero enables diagnostic output. main() sets it to 1 when the DEBUG
 * environment variable is "1" or "true"; it stays 0 otherwise. */
int debug_enabled = 0;
 35
/* Grammar entry points, one per supported language. They are only declared
 * here (the definitions live outside this file); main() calls the one that
 * matches a file's extension and hands the result to the parser. */
TSLanguage *tree_sitter_c(void);
TSLanguage *tree_sitter_cpp(void);
TSLanguage *tree_sitter_go(void);
TSLanguage *tree_sitter_python(void);
TSLanguage *tree_sitter_php(void);
TSLanguage *tree_sitter_rust(void);
TSLanguage *tree_sitter_javascript(void);
TSLanguage *tree_sitter_lua(void);
TSLanguage *tree_sitter_zig(void);
TSLanguage *tree_sitter_kotlin(void);
TSLanguage *tree_sitter_odin(void);
TSLanguage *tree_sitter_tcl(void);
TSLanguage *tree_sitter_glsl(void);
TSLanguage *tree_sitter_cuda(void);
 50
 51#define MIN(a, b) ((a) < (b) ? (a) : (b))
 52
 53int levenshtein_distance(const char *s1, const char *s2) {
 54	unsigned int len1 = strlen(s1);
 55	unsigned int len2 = strlen(s2);
 56	unsigned int distances[len1 + 1][len2 + 1];
 57
 58	for (unsigned int i = 0; i <= len1; i++) {
 59		distances[i][0] = i;
 60	}
 61	for (unsigned int j = 0; j <= len2; j++) {
 62		distances[0][j] = j;
 63	}
 64
 65	for (unsigned int i = 1; i <= len1; i++) {
 66		for (unsigned int j = 1; j <= len2; j++) {
 67			int cost = (s1[i - 1] == s2[j - 1]) ? 0 : 1;
 68			distances[i][j] = MIN(MIN(distances[i - 1][j] + 1, distances[i][j - 1] + 1), distances[i - 1][j - 1] + cost);
 69		}
 70	}
 71
 72	return distances[len1][len2];
 73}
 74
/*
 * One function found by a query match. The three strings are filled in by
 * parse_source_file() from the captures named "fname", "ftype" and "fparams";
 * each is a heap copy produced by extract_value() (or NULL when the capture
 * is absent) and is freed by parse_source_file() after the match is handled.
 */
typedef struct {
	const char *fname;   /* text of the "fname" capture */
	const char *ftype;   /* text of the "ftype" capture */
	const char *fparams; /* text of the "fparams" capture */
	size_t lineno;       /* 1-based line of the "fname" capture (row + 1) */
} Function;
 81
 82const char *extract_value(TSNode captured_node, const char *source_code) {
 83	size_t start = ts_node_start_byte(captured_node);
 84	size_t end = ts_node_end_byte(captured_node);
 85	size_t length = end - start;
 86	char *buffer = malloc(length + 1); // +1 for the null terminator
 87
 88	if (buffer != NULL) {
 89		snprintf(buffer, length + 1, "%.*s", (int)length, &source_code[start]);
 90		return buffer;
 91	} else {
 92		perror("malloc");
 93		exit(EXIT_FAILURE);
 94	}
 95
 96	return NULL;
 97}
 98
 99char *remove_newlines(const char *str) {
100	if (str == NULL)
101		return NULL;
102	size_t length = strlen(str);
103	char *result = (char *)malloc(length + 1); // +1 for the null terminator
104	if (result == NULL) {
105		fprintf(stderr, "Memory allocation failed\n");
106		exit(1);
107	}
108
109	size_t j = 0;
110	for (size_t i = 0; i < length; i++) {
111		if (str[i] != '\n') {
112			result[j++] = str[i];
113		}
114	}
115
116	result[j] = '\0';
117	return result;
118}
119
/*
 * Arguments for one parse_source_file() job. main() allocates one instance
 * per file; parse_source_file() frees both source_code and the struct itself
 * before returning. The other pointers are not freed by the job.
 */
struct ThreadArgs {
	const char *file_path;    /* path printed as the prefix of each result line */
	const char *source_code;  /* NUL-terminated file contents; freed by the job */
	TSLanguage *language;     /* grammar used to parse source_code */
	const char *query_string; /* query text passed to ts_query_new() */
	uint32_t query_len;       /* length of query_string passed to ts_query_new() */
	const char *cfname;       /* search term compared against function names */
	int case_sensitive;       /* non-zero: strstr(); zero: strcasestr() */
	int max_distance;         /* > 0: match by Levenshtein distance <= this value */
};
130
131// void parse_source_file(const char *file_path, const char *source_code,
132// TSLanguage *language, const char *cfname) {
133void parse_source_file(void *arg) {
134	struct ThreadArgs *args = (struct ThreadArgs *)arg;
135
136	const char *file_path = args->file_path;
137	const char *source_code = args->source_code;
138	TSLanguage *language = args->language;
139	const char *cfname = args->cfname;
140	int case_sensitive = args->case_sensitive;
141	int max_distance = args->max_distance;
142
143	TSParser *parser = ts_parser_new();
144	ts_parser_set_language(parser, language);
145
146	TSTree *tree = ts_parser_parse_string(parser, NULL, source_code, strlen(source_code));
147	if (tree == NULL) {
148		if (debug_enabled) {
149			fprintf(stderr, "Parsing failed for file: %s\n", file_path);
150		}
151		ts_parser_delete(parser);
152		free((void *)source_code);
153		free(args);
154		return;
155	}
156	TSNode root_node = ts_tree_root_node(tree);
157
158	const char *query_string = args->query_string;
159	uint32_t query_len = args->query_len;
160
161	uint32_t error_offset;
162	TSQueryError error_type;
163	TSQuery *query = ts_query_new(language, query_string, query_len, &error_offset, &error_type);
164
165	if (query == NULL) {
166		if (debug_enabled) {
167			printf("Query creation failed at offset %u with error type %d\n", error_offset, error_type);
168		}
169		ts_tree_delete(tree);
170		ts_parser_delete(parser);
171		free((void *)source_code);
172		free(args);
173		return;
174	}
175
176	TSQueryCursor *query_cursor = ts_query_cursor_new();
177	ts_query_cursor_exec(query_cursor, query, root_node);
178
179	TSQueryMatch match;
180	while (ts_query_cursor_next_match(query_cursor, &match)) {
181		Function fn = {0};
182
183		for (unsigned i = 0; i < match.capture_count; i++) {
184			TSQueryCapture capture = match.captures[i];
185			TSNode captured_node = capture.node;
186
187			uint32_t capture_name_length;
188			const char *capture_name = ts_query_capture_name_for_id(
189				query, capture.index, &capture_name_length);
190
191			if (strcmp(capture_name, "fname") == 0) {
192				fn.fname = extract_value(captured_node, source_code);
193
194				TSPoint start_point = ts_node_start_point(captured_node);
195				fn.lineno = start_point.row + 1;
196			}
197
198			if (strcmp(capture_name, "ftype") == 0) {
199				fn.ftype = extract_value(captured_node, source_code);
200			}
201
202			if (strcmp(capture_name, "fparams") == 0) {
203				fn.fparams = extract_value(captured_node, source_code);
204			}
205		}
206
207		// Substring matching.
208		if (fn.fname != NULL) {
209			char *result = NULL;
210			int distance = -1;
211
212			if (max_distance > 0) {
213				distance = levenshtein_distance(fn.fname, cfname);
214				if (distance <= max_distance) {
215					// We treat it as a match, but result pointer logic is different
216					// For printing purposes effectively a match.
217					// We'll just set result to non-null to trigger the print.
218					result = (char *)fn.fname;
219				}
220			} else {
221				if (case_sensitive) {
222					result = strstr(fn.fname, cfname);
223				} else {
224					result = strcasestr(fn.fname, cfname);
225				}
226			}
227
228			if (result != NULL) {
229				char *fparams_formatted = remove_newlines(fn.fparams);
230				if (max_distance > 0) {
231					printf("%s:%zu: %s %s %s (dist: %d)\n", file_path, fn.lineno, fn.ftype ? fn.ftype : "", fn.fname, fparams_formatted ? fparams_formatted : "", distance);
232				} else {
233					printf("%s:%zu: %s %s %s\n", file_path, fn.lineno, fn.ftype ? fn.ftype : "", fn.fname, fparams_formatted ? fparams_formatted : "");
234				}
235				free(fparams_formatted);
236			}
237		}
238
239		// Free captured values
240		free((void *)fn.fname);
241		free((void *)fn.ftype);
242		free((void *)fn.fparams);
243	}
244
245	ts_query_cursor_delete(query_cursor);
246	ts_query_delete(query);
247	ts_tree_delete(tree);
248	ts_parser_delete(parser);
249
250	// Cleanup thread arguments
251	free((void *)source_code);
252	free(args);
253}
254
/*
 * Return the extension of file_path: the text after the last '.' in the final
 * path component (the part after the last '/').
 *
 * The result points into file_path and must not be freed. Returns NULL when
 * the final component contains no '.', so a dot in a directory name
 * ("dir.d/Makefile") is not mistaken for an extension. A trailing dot yields
 * an empty string.
 */
const char *get_file_extension(const char *file_path) {
	const char *slash = strrchr(file_path, '/');
	const char *base = (slash != NULL) ? slash + 1 : file_path;

	const char *dot = strrchr(base, '.');
	return (dot != NULL) ? dot + 1 : NULL;
}
262
263int main(int argc, char *argv[]) {
264	int case_sensitive = 0;
265	int max_distance = 0;
266	int max_depth = -1;
267	int opt;
268	struct option long_options[] = {
269		{"case-sensitive", no_argument, 0, 'c'},
270		{"levenshtein", required_argument, 0, 'l'},
271		{"depth", required_argument, 0, 'd'},
272		{0, 0, 0, 0}};
273
274	while ((opt = getopt_long(argc, argv, "cl:d:", long_options, NULL)) != -1) {
275		switch (opt) {
276		case 'c':
277			case_sensitive = 1;
278			break;
279		case 'l':
280			max_distance = atoi(optarg);
281			break;
282		case 'd':
283			max_depth = atoi(optarg);
284			break;
285		default:
286			fprintf(stderr, "Usage: %s [-c|--case-sensitive] [-l|--levenshtein <dist>] [-d|--depth <level>] <search term> [directory|file]\n", argv[0]);
287			return 1;
288		}
289	}
290
291	if (optind >= argc) {
292		fprintf(stderr, "Usage: %s [-c|--case-sensitive] [-l|--levenshtein <dist>] [-d|--depth <level>] <search term> [directory|file]\n", argv[0]);
293		return 1;
294	}
295
296	const char *cfname = argv[optind];
297	char *directory = (optind + 1 < argc) ? argv[optind + 1] : ".";
298
299	Node *head = NULL;
300	list_files_recursively(directory, &head, max_depth, 0);
301	int list_size = size_of_file_list(head);
302
303	const char *debug_env = getenv("DEBUG");
304	if (debug_env != NULL && (strcmp(debug_env, "1") == 0 || strcmp(debug_env, "true") == 0)) {
305		debug_enabled = 1;
306	}
307
308	if (debug_enabled) {
309		printf("Scanning %d files\n", list_size);
310	}
311
312	ThreadPool *pool = tp_create(8);
313	if (!pool) {
314		perror("Failed to create thread pool");
315		return 1;
316	}
317
318	Node *current = head;
319	while (current != NULL) {
320		const char *file_path = current->file_path;
321		const char *extension = get_file_extension(file_path);
322
323		TSLanguage *lang = NULL;
324		const char *query_string = NULL;
325		uint32_t query_len = 0;
326
327		if (extension != NULL) {
328			if (strcmp(extension, "c") == 0 || strcmp(extension, "h") == 0) {
329				lang = tree_sitter_c();
330				query_string = (const char *)query_c;
331				query_len = query_c_len;
332			} else if (strcmp(extension, "cpp") == 0 || strcmp(extension, "hpp") == 0) {
333				lang = tree_sitter_cpp();
334				query_string = (const char *)query_cpp;
335				query_len = query_cpp_len;
336			} else if (strcmp(extension, "go") == 0) {
337				lang = tree_sitter_go();
338				query_string = (const char *)query_go;
339				query_len = query_go_len;
340			} else if (strcmp(extension, "py") == 0) {
341				lang = tree_sitter_python();
342				query_string = (const char *)query_python;
343				query_len = query_python_len;
344			} else if (strcmp(extension, "php") == 0) {
345				lang = tree_sitter_php();
346				query_string = (const char *)query_php;
347				query_len = query_php_len;
348			} else if (strcmp(extension, "rs") == 0) {
349				lang = tree_sitter_rust();
350				query_string = (const char *)query_rust;
351				query_len = query_rust_len;
352			} else if (strcmp(extension, "js") == 0) {
353				lang = tree_sitter_javascript();
354				query_string = (const char *)query_javascript;
355				query_len = query_javascript_len;
356			} else if (strcmp(extension, "lua") == 0) {
357				lang = tree_sitter_lua();
358				query_string = (const char *)query_lua;
359				query_len = query_lua_len;
360			} else if (strcmp(extension, "zig") == 0) {
361				lang = tree_sitter_zig();
362				query_string = (const char *)query_zig;
363				query_len = query_zig_len;
364			} else if (strcmp(extension, "kt") == 0) {
365				lang = tree_sitter_kotlin();
366				query_string = (const char *)query_kotlin;
367				query_len = query_kotlin_len;
368			} else if (strcmp(extension, "odin") == 0) {
369				lang = tree_sitter_odin();
370				query_string = (const char *)query_odin;
371				query_len = query_odin_len;
372			} else if (strcmp(extension, "tcl") == 0) {
373				lang = tree_sitter_tcl();
374				query_string = (const char *)query_tcl;
375				query_len = query_tcl_len;
376			} else if (strcmp(extension, "glsl") == 0) {
377				lang = tree_sitter_glsl();
378				query_string = (const char *)query_glsl;
379				query_len = query_glsl_len;
380			} else if (strcmp(extension, "cu") == 0 || strcmp(extension, "cuh") == 0) {
381				lang = tree_sitter_cuda();
382				query_string = (const char *)query_cuda;
383				query_len = query_cuda_len;
384			}
385		}
386
387		if (lang != NULL && query_string != NULL) {
388			struct FileContent source_file = read_entire_file(file_path);
389			if (source_file.content != NULL) {
390				struct ThreadArgs *thread_args = malloc(sizeof(struct ThreadArgs));
391				if (!thread_args) {
392					perror("Failed to allocate thread args");
393					free((void *)source_file.content);
394					continue;
395				}
396
397				thread_args->file_path = file_path;
398				thread_args->source_code = source_file.content;
399				thread_args->language = lang;
400				thread_args->query_string = query_string;
401				thread_args->query_len = query_len;
402				thread_args->cfname = cfname;
403				thread_args->case_sensitive = case_sensitive;
404				thread_args->max_distance = max_distance;
405
406				tp_add_job(pool, (thread_func_t)parse_source_file, thread_args);
407			} else {
408				if (debug_enabled) {
409					fprintf(stderr, "Failed to read file: %s\n", file_path);
410				}
411			}
412		}
413
414		current = current->next;
415	}
416
417	tp_wait(pool);
418	tp_destroy(pool);
419	free_file_list(head);
420	return 0;
421}