1#ifndef PRE_WGSL_HPP
2#define PRE_WGSL_HPP
3
#include <cctype>
#include <fstream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
13
14namespace pre_wgsl {
15
16//==============================================================
17// Options
18//==============================================================
struct Options {
    // Directory that #include file names are resolved against ("." = current directory).
    std::string include_path{ "." };
    // Global macro definitions, each written as "NAME" or "NAME=VALUE".
    std::vector<std::string> macros{};
};
23
24//==============================================================
25// Utility: trim
26//==============================================================
// Returns a copy of `s` without leading and trailing whitespace (as classified by std::isspace).
static std::string trim(const std::string & s) {
    auto is_ws = [](char ch) { return std::isspace(static_cast<unsigned char>(ch)) != 0; };

    size_t first = 0;
    size_t last  = s.size();
    for (; first < last && is_ws(s[first]); ++first) {
    }
    for (; last > first && is_ws(s[last - 1]); --last) {
    }
    return s.substr(first, last - first);
}
38
39static std::string trim_value(std::istream & is) {
40 std::string str;
41 std::getline(is, str);
42 return trim(str);
43}
44
// Returns true if `c` may appear in an identifier: '_' or any character that
// std::isalnum accepts. Digits are included, so a run such as `1u` forms one word.
static bool isIdentChar(char c) {
    if (c == '_') {
        return true;
    }
    const unsigned char uc = static_cast<unsigned char>(c);
    return std::isalnum(uc) != 0;
}
48
// Forward declaration: expandMacroValue (defined next) and expandMacrosRecursiveInternal
// call each other, so that a macro whose value contains further macro names is expanded
// all the way down. `visiting` holds the macros currently being expanded.
static std::string expandMacrosRecursiveInternal(const std::string & line,
                                                 const std::unordered_map<std::string, std::string> & macros,
                                                 std::unordered_set<std::string> & visiting);
52
53static std::string expandMacroValue(const std::string & name,
54 const std::unordered_map<std::string, std::string> & macros,
55 std::unordered_set<std::string> & visiting) {
56 if (visiting.count(name)) {
57 throw std::runtime_error("Recursive macro: " + name);
58 }
59 visiting.insert(name);
60
61 auto it = macros.find(name);
62 if (it == macros.end()) {
63 visiting.erase(name);
64 return name;
65 }
66
67 const std::string & value = it->second;
68 if (value.empty()) {
69 visiting.erase(name);
70 return "";
71 }
72
73 std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
74 visiting.erase(name);
75 return expanded;
76}
77
78static std::string expandMacrosRecursiveInternal(const std::string & line,
79 const std::unordered_map<std::string, std::string> & macros,
80 std::unordered_set<std::string> & visiting) {
81 std::string result;
82 result.reserve(line.size());
83
84 size_t i = 0;
85 while (i < line.size()) {
86 if (isIdentChar(line[i])) {
87 size_t start = i;
88 while (i < line.size() && isIdentChar(line[i])) {
89 i++;
90 }
91 std::string token = line.substr(start, i - start);
92
93 auto it = macros.find(token);
94 if (it != macros.end()) {
95 result += expandMacroValue(token, macros, visiting);
96 } else {
97 result += token;
98 }
99 } else {
100 result += line[i];
101 i++;
102 }
103 }
104
105 return result;
106}
107
108static std::string expandMacrosRecursive(const std::string & line,
109 const std::unordered_map<std::string, std::string> & macros) {
110 std::unordered_set<std::string> visiting;
111 return expandMacrosRecursiveInternal(line, macros, visiting);
112}
113
114//==============================================================
115// Tokenizer for expressions in #if/#elif
116//==============================================================
// Splits an #if/#elif expression into tokens.
class ExprLexer {
  public:
    enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };

    struct Tok {
        Kind        kind = END;
        std::string text;
    };

    // `sv` must outlive the lexer (only a view is stored).
    explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}

    // Returns the next token.
    //  - Whitespace and comments are skipped: `// ...` runs to the end of the input, and
    //    `/* ... */` is skipped (an unterminated `/*` runs to the end of the input).
    //  - A character that cannot start a token ends the stream. This is best-effort: the rest
    //    of the expression is ignored.
    Tok next() {
        skipSpaceAndComments();
        if (pos >= src.size()) {
            return { END, "" };
        }

        const char c = src[pos];

        // Integer literal, lexed like a C "pp-number": a digit followed by letters, digits
        // and '_'. This keeps `0x1F` and suffixed literals such as WGSL's `256u` as one token.
        if (std::isdigit(static_cast<unsigned char>(c))) {
            return { NUMBER, takeWord() };
        }

        // identifier
        if (std::isalpha(static_cast<unsigned char>(c)) || c == '_') {
            return { IDENT, takeWord() };
        }

        if (c == '(') {
            pos++;
            return { LPAREN, "(" };
        }
        if (c == ')') {
            pos++;
            return { RPAREN, ")" };
        }

        // Two-char operators are tried before their one-char prefixes ("<<" before "<", "&&" before "&").
        static const char * const two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
        for (const char * op : two_ops) {
            if (src.substr(pos, 2) == op) {
                pos += 2;
                return { OP, std::string(op) };
            }
        }

        static constexpr std::string_view one_ops = "+-*/%<>!&|^~";
        if (one_ops.find(c) != std::string_view::npos) {
            pos++;
            return { OP, std::string(1, c) };
        }

        // unexpected character: stop (best-effort)
        pos++;
        return { END, "" };
    }

  private:
    std::string_view src;
    size_t           pos;

    static bool isWordChar(char ch) { return std::isalnum(static_cast<unsigned char>(ch)) || ch == '_'; }

    // Consumes a run of word characters starting at `pos` and returns it.
    std::string takeWord() {
        const size_t start = pos;
        while (pos < src.size() && isWordChar(src[pos])) {
            pos++;
        }
        return std::string(src.substr(start, pos - start));
    }

    void skipSpaceAndComments() {
        for (;;) {
            while (pos < src.size() && std::isspace(static_cast<unsigned char>(src[pos]))) {
                pos++;
            }
            if (src.substr(pos, 2) == "//") {
                pos = src.size();
                return;
            }
            if (src.substr(pos, 2) == "/*") {
                const size_t close = src.find("*/", pos + 2);
                pos = (close == std::string_view::npos) ? src.size() : close + 2;
                continue;
            }
            return;
        }
    }
};

//==============================================================
// Expression Parser (recursive descent)
//==============================================================
// Evaluates #if/#elif expressions with C preprocessor precedence (lowest to highest):
//   ||   &&   |   ^   &   == !=   < <= > >=   << >>   + -   * / %   unary ! ~ - +
//
// Operands:
//  - integer literals: decimal, or hex with a 0x/0X prefix; suffixes such as `u` are ignored;
//  - defined(NAME) and defined NAME;
//  - identifiers. An undefined identifier is 0, a macro with an empty value is 1, and any
//    other macro value is evaluated recursively as an expression.
//
// Arithmetic is done on `int` with two's-complement wrap-around instead of undefined behaviour.
// Division or modulo by zero yields 0. Shift counts outside [0, 32) yield 0 (or -1 for `>>` of
// a negative value). Both operands of && and || are always evaluated.
class ExprParser {
  public:
    // `visiting` tracks macros currently being evaluated, so self-referential definitions are
    // reported instead of recursing forever. `expr`, `macros` and `visiting` must outlive the parser.
    ExprParser(std::string_view expr,
               const std::unordered_map<std::string, std::string> & macros,
               std::unordered_set<std::string> & visiting) :
        lex(expr),
        macros(macros),
        visiting(visiting) {
        advance();
    }

    // Evaluates the expression; non-zero means true.
    int parse() { return parseLogicalOr(); }

  private:
    ExprLexer                                            lex;
    ExprLexer::Tok                                       tok;
    const std::unordered_map<std::string, std::string> & macros;
    std::unordered_set<std::string> &                    visiting;

    static constexpr int kIntBits = std::numeric_limits<unsigned int>::digits;

    // Helpers for wrap-around arithmetic: compute on the unsigned bit pattern, then convert back.
    static unsigned int bitsOf(int v) { return static_cast<unsigned int>(v); }

    static int fromBits(unsigned int u) { return static_cast<int>(u); }

    void advance() { tok = lex.next(); }

    bool acceptOp(const std::string & s) {
        if (tok.kind == ExprLexer::OP && tok.text == s) {
            advance();
            return true;
        }
        return false;
    }

    bool acceptKind(ExprLexer::Kind k) {
        if (tok.kind == k) {
            advance();
            return true;
        }
        return false;
    }

    int parseLogicalOr() {
        int v = parseLogicalAnd();
        while (acceptOp("||")) {
            int rhs = parseLogicalAnd();
            v       = (v || rhs);
        }
        return v;
    }

    int parseLogicalAnd() {
        int v = parseBitOr();
        while (acceptOp("&&")) {
            int rhs = parseBitOr();
            v       = (v && rhs);
        }
        return v;
    }

    int parseBitOr() {
        int v = parseBitXor();
        while (acceptOp("|")) {
            v |= parseBitXor();
        }
        return v;
    }

    int parseBitXor() {
        int v = parseBitAnd();
        while (acceptOp("^")) {
            v ^= parseBitAnd();
        }
        return v;
    }

    int parseBitAnd() {
        int v = parseEquality();
        while (acceptOp("&")) {
            v &= parseEquality();
        }
        return v;
    }

    int parseEquality() {
        int v = parseRelational();
        for (;;) {
            if (acceptOp("==")) {
                int rhs = parseRelational();
                v       = (v == rhs);
            } else if (acceptOp("!=")) {
                int rhs = parseRelational();
                v       = (v != rhs);
            } else {
                break;
            }
        }
        return v;
    }

    int parseRelational() {
        int v = parseShift();
        for (;;) {
            if (acceptOp("<")) {
                int rhs = parseShift();
                v       = (v < rhs);
            } else if (acceptOp(">")) {
                int rhs = parseShift();
                v       = (v > rhs);
            } else if (acceptOp("<=")) {
                int rhs = parseShift();
                v       = (v <= rhs);
            } else if (acceptOp(">=")) {
                int rhs = parseShift();
                v       = (v >= rhs);
            } else {
                break;
            }
        }
        return v;
    }

    // `v << count` without UB: out-of-range counts give 0; overflowing bits are discarded.
    static int shiftLeft(int v, int count) {
        if (count < 0 || count >= kIntBits) {
            return 0;
        }
        return fromBits(bitsOf(v) << count);
    }

    // `v >> count` without UB. Out-of-range counts saturate to the sign: -1 for negative v, else 0.
    static int shiftRight(int v, int count) {
        if (count < 0 || count >= kIntBits) {
            return v < 0 ? -1 : 0;
        }
        return v >> count;
    }

    int parseShift() {
        int v = parseAdd();
        for (;;) {
            if (acceptOp("<<")) {
                v = shiftLeft(v, parseAdd());
            } else if (acceptOp(">>")) {
                v = shiftRight(v, parseAdd());
            } else {
                break;
            }
        }
        return v;
    }

    int parseAdd() {
        int v = parseMult();
        for (;;) {
            if (acceptOp("+")) {
                int rhs = parseMult();
                v       = fromBits(bitsOf(v) + bitsOf(rhs));
            } else if (acceptOp("-")) {
                int rhs = parseMult();
                v       = fromBits(bitsOf(v) - bitsOf(rhs));
            } else {
                break;
            }
        }
        return v;
    }

    int parseMult() {
        int v = parseUnary();
        for (;;) {
            if (acceptOp("*")) {
                int rhs = parseUnary();
                v       = fromBits(bitsOf(v) * bitsOf(rhs));
            } else if (acceptOp("/")) {
                int rhs = parseUnary();
                // A 64-bit intermediate avoids the INT_MIN / -1 overflow; the result then wraps.
                v = (rhs == 0) ? 0 : fromBits(static_cast<unsigned int>(static_cast<long long>(v) / rhs));
            } else if (acceptOp("%")) {
                int rhs = parseUnary();
                v       = (rhs == 0) ? 0 : static_cast<int>(static_cast<long long>(v) % rhs);
            } else {
                break;
            }
        }
        return v;
    }

    int parseUnary() {
        if (acceptOp("!")) {
            return !parseUnary();
        }
        if (acceptOp("~")) {
            return ~parseUnary();
        }
        if (acceptOp("-")) {
            return fromBits(0u - bitsOf(parseUnary()));
        }
        if (acceptOp("+")) {
            return parseUnary();
        }
        return parsePrimary();
    }

    // Parses an integer literal token: decimal, or hexadecimal with a 0x/0X prefix.
    // Trailing suffix characters (e.g. `u`) are ignored. Values up to UINT_MAX are accepted, and
    // values above INT_MAX wrap to negative ints. Larger values throw std::runtime_error.
    static int parseNumber(const std::string & text) {
        unsigned int base = 10;
        size_t       i    = 0;
        if (text.size() >= 2 && text[0] == '0' && (text[1] == 'x' || text[1] == 'X')) {
            base = 16;
            i    = 2;
        }

        unsigned long long value   = 0;
        size_t             ndigits = 0;
        for (; i < text.size(); ++i) {
            const char   ch = text[i];
            unsigned int digit;
            if (ch >= '0' && ch <= '9') {
                digit = static_cast<unsigned int>(ch - '0');
            } else if (base == 16 && ch >= 'a' && ch <= 'f') {
                digit = static_cast<unsigned int>(ch - 'a' + 10);
            } else if (base == 16 && ch >= 'A' && ch <= 'F') {
                digit = static_cast<unsigned int>(ch - 'A' + 10);
            } else {
                break;  // start of a suffix
            }
            value = value * base + digit;
            if (value > std::numeric_limits<unsigned int>::max()) {
                throw std::runtime_error("Integer literal out of range: " + text);
            }
            ++ndigits;
        }
        if (ndigits == 0) {
            throw std::runtime_error("Invalid integer literal: " + text);
        }
        return fromBits(static_cast<unsigned int>(value));
    }

    int parsePrimary() {
        // '(' expr ')'
        if (acceptKind(ExprLexer::LPAREN)) {
            int v = parse();
            if (!acceptKind(ExprLexer::RPAREN)) {
                throw std::runtime_error("missing ')'");
            }
            return v;
        }

        // number
        if (tok.kind == ExprLexer::NUMBER) {
            int v = parseNumber(tok.text);
            advance();
            return v;
        }

        // defined(identifier)
        if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
            advance();
            if (acceptKind(ExprLexer::LPAREN)) {
                if (tok.kind != ExprLexer::IDENT) {
                    throw std::runtime_error("expected identifier in defined()");
                }
                std::string name = tok.text;
                advance();
                if (!acceptKind(ExprLexer::RPAREN)) {
                    throw std::runtime_error("missing ) in defined()");
                }
                return macros.count(name) ? 1 : 0;
            }
            // defined NAME
            if (tok.kind != ExprLexer::IDENT) {
                throw std::runtime_error("expected identifier in defined NAME");
            }
            std::string name = tok.text;
            advance();
            return macros.count(name) ? 1 : 0;
        }

        // identifier -> 0 if undefined, 1 if defined with an empty value, else its evaluated value
        if (tok.kind == ExprLexer::IDENT) {
            std::string name = tok.text;
            advance();
            auto it = macros.find(name);
            if (it == macros.end()) {
                return 0;
            }
            if (it->second.empty()) {
                return 1;
            }
            return evalMacroExpression(name, it->second);
        }

        // Unexpected token (or end of input): best-effort value 0.
        return 0;
    }

    // Evaluates a macro's value as an expression, guarding against self-referential definitions.
    int evalMacroExpression(const std::string & name, const std::string & value) {
        if (visiting.count(name)) {
            throw std::runtime_error("Recursive macro: " + name);
        }

        visiting.insert(name);
        ExprParser ep(value, macros, visiting);
        int        v = ep.parse();
        visiting.erase(name);
        return v;
    }
};
425
426//==============================================================
427// Preprocessor
428//==============================================================
429class Preprocessor {
430 public:
431 explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
432 // Treat empty include path as current directory
433 if (opts_.include_path.empty()) {
434 opts_.include_path = ".";
435 }
436 parseMacroDefinitions(opts_.macros);
437 }
438
439 std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
440 std::unordered_map<std::string, std::string> macros;
441 std::unordered_set<std::string> predefined;
442 std::unordered_set<std::string> include_stack;
443 buildMacros(additional_macros, macros, predefined);
444
445 std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
446 return result;
447 }
448
449 std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
450 std::unordered_map<std::string, std::string> macros;
451 std::unordered_set<std::string> predefined;
452 std::unordered_set<std::string> include_stack;
453 buildMacros(additional_macros, macros, predefined);
454
455 std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
456 return result;
457 }
458
459 std::string preprocess_includes_file(const std::string & filename) {
460 std::unordered_map<std::string, std::string> macros;
461 std::unordered_set<std::string> predefined;
462 std::unordered_set<std::string> include_stack;
463 std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
464 return result;
465 }
466
467 std::string preprocess_includes(const std::string & contents) {
468 std::unordered_map<std::string, std::string> macros;
469 std::unordered_set<std::string> predefined;
470 std::unordered_set<std::string> include_stack;
471 std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
472 return result;
473 }
474
475 private:
476 Options opts_;
477 std::unordered_map<std::string, std::string> global_macros;
478
479 enum class DirectiveMode { All, IncludesOnly };
480
481 struct Cond {
482 bool parent_active;
483 bool active;
484 bool taken;
485 };
486
487 //----------------------------------------------------------
488 // Parse macro definitions into global_macros
489 //----------------------------------------------------------
490 void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
491 for (const auto & def : macro_defs) {
492 size_t eq_pos = def.find('=');
493 if (eq_pos != std::string::npos) {
494 // Format: NAME=VALUE
495 std::string name = trim(def.substr(0, eq_pos));
496 std::string value = trim(def.substr(eq_pos + 1));
497 global_macros[name] = value;
498 } else {
499 // Format: NAME
500 std::string name = trim(def);
501 global_macros[name] = "";
502 }
503 }
504 }
505
506 //----------------------------------------------------------
507 // Build combined macro map and predefined set for a preprocessing operation
508 //----------------------------------------------------------
509 void buildMacros(const std::vector<std::string> & additional_macros,
510 std::unordered_map<std::string, std::string> & macros,
511 std::unordered_set<std::string> & predefined) {
512 macros = global_macros;
513 predefined.clear();
514
515 for (const auto & [name, value] : global_macros) {
516 predefined.insert(name);
517 }
518
519 for (const auto & def : additional_macros) {
520 size_t eq_pos = def.find('=');
521 std::string name, value;
522 if (eq_pos != std::string::npos) {
523 name = trim(def.substr(0, eq_pos));
524 value = trim(def.substr(eq_pos + 1));
525 } else {
526 name = trim(def);
527 value = "";
528 }
529
530 // Add to macros map (will override global if same name)
531 macros[name] = value;
532 predefined.insert(name);
533 }
534 }
535
536 //----------------------------------------------------------
537 // Helpers
538 //----------------------------------------------------------
539 std::string loadFile(const std::string & fname) {
540 std::ifstream f(fname);
541 if (!f.is_open()) {
542 throw std::runtime_error("Could not open file: " + fname);
543 }
544 std::stringstream ss;
545 ss << f.rdbuf();
546 return ss.str();
547 }
548
549 bool condActive(const std::vector<Cond> & cond) const {
550 if (cond.empty()) {
551 return true;
552 }
553 return cond.back().active;
554 }
555
556 //----------------------------------------------------------
557 // Process a file
558 //----------------------------------------------------------
559 std::string processFile(const std::string & name,
560 std::unordered_map<std::string, std::string> & macros,
561 const std::unordered_set<std::string> & predefined_macros,
562 std::unordered_set<std::string> & include_stack,
563 DirectiveMode mode) {
564 if (include_stack.count(name)) {
565 throw std::runtime_error("Recursive include: " + name);
566 }
567
568 include_stack.insert(name);
569 std::string shader_code = loadFile(name);
570 std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode);
571 include_stack.erase(name);
572 return out;
573 }
574
575 std::string processIncludeFile(const std::string & fname,
576 std::unordered_map<std::string, std::string> & macros,
577 const std::unordered_set<std::string> & predefined_macros,
578 std::unordered_set<std::string> & include_stack,
579 DirectiveMode mode) {
580 std::string full_path = opts_.include_path + "/" + fname;
581 return processFile(full_path, macros, predefined_macros, include_stack, mode);
582 }
583
584 //----------------------------------------------------------
585 // Process text
586 //----------------------------------------------------------
587 std::string processString(const std::string & shader_code,
588 std::unordered_map<std::string, std::string> & macros,
589 const std::unordered_set<std::string> & predefined_macros,
590 std::unordered_set<std::string> & include_stack,
591 DirectiveMode mode) {
592 std::vector<Cond> cond; // Conditional stack for this shader
593 std::stringstream out;
594 std::istringstream in(shader_code);
595 std::string line;
596
597 while (std::getline(in, line)) {
598 std::string t = trim(line);
599
600 if (!t.empty() && t[0] == '#') {
601 bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
602 if (mode == DirectiveMode::IncludesOnly && !handled) {
603 out << line << "\n";
604 }
605 } else {
606 if (mode == DirectiveMode::IncludesOnly) {
607 out << line << "\n";
608 } else if (condActive(cond)) {
609 // Expand macros in the line before outputting
610 std::string expanded = expandMacrosRecursive(line, macros);
611 out << expanded << "\n";
612 }
613 }
614 }
615
616 if (mode == DirectiveMode::All && !cond.empty()) {
617 throw std::runtime_error("Unclosed #if directive");
618 }
619
620 return out.str();
621 }
622
623 //----------------------------------------------------------
624 // Directive handler
625 //----------------------------------------------------------
626 bool handleDirective(const std::string & t,
627 std::stringstream & out,
628 std::unordered_map<std::string, std::string> & macros,
629 const std::unordered_set<std::string> & predefined_macros,
630 std::vector<Cond> & cond,
631 std::unordered_set<std::string> & include_stack,
632 DirectiveMode mode) {
633 // split into tokens
634 std::string body = t.substr(1);
635 std::istringstream iss(body);
636 std::string cmd;
637 iss >> cmd;
638
639 if (cmd == "include") {
640 if (mode == DirectiveMode::All && !condActive(cond)) {
641 return true;
642 }
643 std::string file;
644 iss >> file;
645 if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
646 file = file.substr(1, file.size() - 2);
647 }
648 out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
649 return true;
650 }
651
652 if (mode == DirectiveMode::IncludesOnly) {
653 return false;
654 }
655
656 if (cmd == "define") {
657 if (!condActive(cond)) {
658 return true;
659 }
660 std::string name;
661 iss >> name;
662 // Don't override predefined macros from options
663 if (predefined_macros.count(name)) {
664 return true;
665 }
666 std::string value = trim_value(iss);
667 macros[name] = value;
668 return true;
669 }
670
671 if (cmd == "undef") {
672 if (!condActive(cond)) {
673 return true;
674 }
675 std::string name;
676 iss >> name;
677 // Don't undef predefined macros from options
678 if (predefined_macros.count(name)) {
679 return true;
680 }
681 macros.erase(name);
682 return true;
683 }
684
685 if (cmd == "ifdef") {
686 std::string name;
687 iss >> name;
688 bool p = condActive(cond);
689 bool v = macros.count(name);
690 cond.push_back({ p, p && v, p && v });
691 return true;
692 }
693
694 if (cmd == "ifndef") {
695 std::string name;
696 iss >> name;
697 bool p = condActive(cond);
698 bool v = !macros.count(name);
699 cond.push_back({ p, p && v, p && v });
700 return true;
701 }
702
703 if (cmd == "if") {
704 std::string expr = trim_value(iss);
705 bool p = condActive(cond);
706 bool v = false;
707 if (p) {
708 std::unordered_set<std::string> visiting;
709 ExprParser ep(expr, macros, visiting);
710 v = ep.parse() != 0;
711 }
712 cond.push_back({ p, p && v, p && v });
713 return true;
714 }
715
716 if (cmd == "elif") {
717 std::string expr = trim_value(iss);
718
719 if (cond.empty()) {
720 throw std::runtime_error("#elif without #if");
721 }
722
723 Cond & c = cond.back();
724 if (!c.parent_active) {
725 c.active = false;
726 return true;
727 }
728
729 if (c.taken) {
730 c.active = false;
731 return true;
732 }
733
734 std::unordered_set<std::string> visiting;
735 ExprParser ep(expr, macros, visiting);
736 bool v = ep.parse() != 0;
737 c.active = v;
738 if (v) {
739 c.taken = true;
740 }
741 return true;
742 }
743
744 if (cmd == "else") {
745 if (cond.empty()) {
746 throw std::runtime_error("#else without #if");
747 }
748
749 Cond & c = cond.back();
750 if (!c.parent_active) {
751 c.active = false;
752 return true;
753 }
754 if (c.taken) {
755 c.active = false;
756 } else {
757 c.active = true;
758 c.taken = true;
759 }
760 return true;
761 }
762
763 if (cmd == "endif") {
764 if (cond.empty()) {
765 throw std::runtime_error("#endif without #if");
766 }
767 cond.pop_back();
768 return true;
769 }
770
771 // Unknown directive
772 throw std::runtime_error("Unknown directive: #" + cmd);
773 }
774};
775
776} // namespace pre_wgsl
777
778#endif // PRE_WGSL_HPP