1#ifndef PRE_WGSL_HPP
  2#define PRE_WGSL_HPP
  3
#include <cctype>
#include <fstream>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
 13
 14namespace pre_wgsl {
 15
 16//==============================================================
 17// Options
 18//==============================================================
 19struct Options {
 20    std::string              include_path = ".";
 21    std::vector<std::string> macros;
 22};
 23
 24//==============================================================
// Utilities: trimming, identifier classification, macro expansion
 26//==============================================================
 27static std::string trim(const std::string & s) {
 28    size_t a = 0;
 29    while (a < s.size() && std::isspace((unsigned char) s[a])) {
 30        a++;
 31    }
 32    size_t b = s.size();
 33    while (b > a && std::isspace((unsigned char) s[b - 1])) {
 34        b--;
 35    }
 36    return s.substr(a, b - a);
 37}
 38
 39static std::string trim_value(std::istream & is) {
 40    std::string str;
 41    std::getline(is, str);
 42    return trim(str);
 43}
 44
 45static bool isIdentChar(char c) {
 46    return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
 47}
 48
 49static std::string expandMacrosRecursiveInternal(const std::string &                                  line,
 50                                                 const std::unordered_map<std::string, std::string> & macros,
 51                                                 std::unordered_set<std::string> &                    visiting);
 52
 53static std::string expandMacroValue(const std::string &                                  name,
 54                                    const std::unordered_map<std::string, std::string> & macros,
 55                                    std::unordered_set<std::string> &                    visiting) {
 56    if (visiting.count(name)) {
 57        throw std::runtime_error("Recursive macro: " + name);
 58    }
 59    visiting.insert(name);
 60
 61    auto it = macros.find(name);
 62    if (it == macros.end()) {
 63        visiting.erase(name);
 64        return name;
 65    }
 66
 67    const std::string & value = it->second;
 68    if (value.empty()) {
 69        visiting.erase(name);
 70        return "";
 71    }
 72
 73    std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
 74    visiting.erase(name);
 75    return expanded;
 76}
 77
 78static std::string expandMacrosRecursiveInternal(const std::string &                                  line,
 79                                                 const std::unordered_map<std::string, std::string> & macros,
 80                                                 std::unordered_set<std::string> &                    visiting) {
 81    std::string result;
 82    result.reserve(line.size());
 83
 84    size_t i = 0;
 85    while (i < line.size()) {
 86        if (isIdentChar(line[i])) {
 87            size_t start = i;
 88            while (i < line.size() && isIdentChar(line[i])) {
 89                i++;
 90            }
 91            std::string token = line.substr(start, i - start);
 92
 93            auto it = macros.find(token);
 94            if (it != macros.end()) {
 95                result += expandMacroValue(token, macros, visiting);
 96            } else {
 97                result += token;
 98            }
 99        } else {
100            result += line[i];
101            i++;
102        }
103    }
104
105    return result;
106}
107
108static std::string expandMacrosRecursive(const std::string &                                  line,
109                                         const std::unordered_map<std::string, std::string> & macros) {
110    std::unordered_set<std::string> visiting;
111    return expandMacrosRecursiveInternal(line, macros, visiting);
112}
113
114//==============================================================
115// Tokenizer for expressions in #if/#elif
116//==============================================================
// Splits the expression text of an #if / #elif directive into tokens.
//
// Only a view of the text passed to the constructor is kept, so that text has
// to stay alive for as long as next() is being called.
class ExprLexer {
  public:
    // Token categories. END is produced at the end of the input and also for a
    // character that does not start any known token.
    enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };

    struct Tok {
        Kind        kind;
        std::string text;  // spelling of the token; empty for END
    };

    explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}

    // Returns the next token, skipping the whitespace in front of it.
    Tok next() {
        skipWS();
        if (pos >= src.size()) {
            return { END, "" };
        }

        const char first = src[pos];

        // Decimal digit run.
        if (isDigit(first)) {
            return { NUMBER, consumeWhile(isDigit) };
        }

        // Letter or underscore followed by letters, digits and underscores.
        if (isIdentStart(first)) {
            return { IDENT, consumeWhile(isIdentPart) };
        }

        switch (first) {
            case '(':
                ++pos;
                return { LPAREN, "(" };
            case ')':
                ++pos;
                return { RPAREN, ")" };
            default:
                break;
        }

        // Two-character operators are tried before the one-character ones so
        // that e.g. "<=" is not split into "<" and a stray "=".
        static constexpr std::string_view kTwoCharOps[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
        const std::string_view            lookahead     = src.substr(pos, 2);
        for (const std::string_view candidate : kTwoCharOps) {
            if (lookahead == candidate) {
                pos += 2;
                return { OP, std::string(candidate) };
            }
        }

        static constexpr std::string_view kOneCharOps = "+-*/%<>!";
        const bool                        is_operator = kOneCharOps.find(first) != std::string_view::npos;

        // The character is consumed whether or not it is an operator; an
        // unrecognised character is reported as END.
        ++pos;
        if (is_operator) {
            return { OP, std::string(1, first) };
        }
        return { END, "" };
    }

  private:
    std::string_view src;
    size_t           pos;

    static bool isDigit(char ch) { return std::isdigit(static_cast<unsigned char>(ch)) != 0; }

    static bool isIdentStart(char ch) { return ch == '_' || std::isalpha(static_cast<unsigned char>(ch)) != 0; }

    static bool isIdentPart(char ch) { return ch == '_' || std::isalnum(static_cast<unsigned char>(ch)) != 0; }

    // Advances over the longest run of characters accepted by `accept` and
    // returns that run.
    std::string consumeWhile(bool (*accept)(char)) {
        const size_t begin = pos;
        while (pos < src.size() && accept(src[pos])) {
            ++pos;
        }
        return std::string(src.substr(begin, pos - begin));
    }

    void skipWS() {
        for (; pos < src.size(); ++pos) {
            if (!std::isspace(static_cast<unsigned char>(src[pos]))) {
                break;
            }
        }
    }
};
193
194//==============================================================
195// Expression Parser (recursive descent)
196//==============================================================
197class ExprParser {
198  public:
199    ExprParser(std::string_view                                     expr,
200               const std::unordered_map<std::string, std::string> & macros,
201               std::unordered_set<std::string> &                    visiting) :
202        lex(expr),
203        macros(macros),
204        visiting(visiting) {
205        advance();
206    }
207
208    int parse() { return parseLogicalOr(); }
209
210  private:
211    ExprLexer                                            lex;
212    ExprLexer::Tok                                       tok;
213    const std::unordered_map<std::string, std::string> & macros;
214    std::unordered_set<std::string> &                    visiting;
215
216    void advance() { tok = lex.next(); }
217
218    bool acceptOp(const std::string & s) {
219        if (tok.kind == ExprLexer::OP && tok.text == s) {
220            advance();
221            return true;
222        }
223        return false;
224    }
225
226    bool acceptKind(ExprLexer::Kind k) {
227        if (tok.kind == k) {
228            advance();
229            return true;
230        }
231        return false;
232    }
233
234    int parseLogicalOr() {
235        int v = parseLogicalAnd();
236        while (acceptOp("||")) {
237            int rhs = parseLogicalAnd();
238            v       = (v || rhs);
239        }
240        return v;
241    }
242
243    int parseLogicalAnd() {
244        int v = parseEquality();
245        while (acceptOp("&&")) {
246            int rhs = parseEquality();
247            v       = (v && rhs);
248        }
249        return v;
250    }
251
252    int parseEquality() {
253        int v = parseRelational();
254        for (;;) {
255            if (acceptOp("==")) {
256                int rhs = parseRelational();
257                v       = (v == rhs);
258            } else if (acceptOp("!=")) {
259                int rhs = parseRelational();
260                v       = (v != rhs);
261            } else {
262                break;
263            }
264        }
265        return v;
266    }
267
268    int parseRelational() {
269        int v = parseShift();
270        for (;;) {
271            if (acceptOp("<")) {
272                int rhs = parseShift();
273                v       = (v < rhs);
274            } else if (acceptOp(">")) {
275                int rhs = parseShift();
276                v       = (v > rhs);
277            } else if (acceptOp("<=")) {
278                int rhs = parseShift();
279                v       = (v <= rhs);
280            } else if (acceptOp(">=")) {
281                int rhs = parseShift();
282                v       = (v >= rhs);
283            } else {
284                break;
285            }
286        }
287        return v;
288    }
289
290    int parseShift() {
291        int v = parseAdd();
292        for (;;) {
293            if (acceptOp("<<")) {
294                int rhs = parseAdd();
295                v       = (v << rhs);
296            } else if (acceptOp(">>")) {
297                int rhs = parseAdd();
298                v       = (v >> rhs);
299            } else {
300                break;
301            }
302        }
303        return v;
304    }
305
306    int parseAdd() {
307        int v = parseMult();
308        for (;;) {
309            if (acceptOp("+")) {
310                int rhs = parseMult();
311                v       = (v + rhs);
312            } else if (acceptOp("-")) {
313                int rhs = parseMult();
314                v       = (v - rhs);
315            } else {
316                break;
317            }
318        }
319        return v;
320    }
321
322    int parseMult() {
323        int v = parseUnary();
324        for (;;) {
325            if (acceptOp("*")) {
326                int rhs = parseUnary();
327                v       = (v * rhs);
328            } else if (acceptOp("/")) {
329                int rhs = parseUnary();
330                v       = (rhs == 0 ? 0 : v / rhs);
331            } else if (acceptOp("%")) {
332                int rhs = parseUnary();
333                v       = (rhs == 0 ? 0 : v % rhs);
334            } else {
335                break;
336            }
337        }
338        return v;
339    }
340
341    int parseUnary() {
342        if (acceptOp("!")) {
343            return !parseUnary();
344        }
345        if (acceptOp("-")) {
346            return -parseUnary();
347        }
348        if (acceptOp("+")) {
349            return +parseUnary();
350        }
351        return parsePrimary();
352    }
353
354    int parsePrimary() {
355        // '(' expr ')'
356        if (acceptKind(ExprLexer::LPAREN)) {
357            int v = parse();
358            if (!acceptKind(ExprLexer::RPAREN)) {
359                throw std::runtime_error("missing ')'");
360            }
361            return v;
362        }
363
364        // number
365        if (tok.kind == ExprLexer::NUMBER) {
366            int v = std::stoi(tok.text);
367            advance();
368            return v;
369        }
370
371        // defined(identifier)
372        if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
373            advance();
374            if (acceptKind(ExprLexer::LPAREN)) {
375                if (tok.kind != ExprLexer::IDENT) {
376                    throw std::runtime_error("expected identifier in defined()");
377                }
378                std::string name = tok.text;
379                advance();
380                if (!acceptKind(ExprLexer::RPAREN)) {
381                    throw std::runtime_error("missing ) in defined()");
382                }
383                return macros.count(name) ? 1 : 0;
384            } else {
385                // defined NAME
386                if (tok.kind != ExprLexer::IDENT) {
387                    throw std::runtime_error("expected identifier in defined NAME");
388                }
389                std::string name = tok.text;
390                advance();
391                return macros.count(name) ? 1 : 0;
392            }
393        }
394
395        // identifier -> treat as integer, if defined use its value else 0
396        if (tok.kind == ExprLexer::IDENT) {
397            std::string name = tok.text;
398            advance();
399            auto it = macros.find(name);
400            if (it == macros.end()) {
401                return 0;
402            }
403            if (it->second.empty()) {
404                return 1;
405            }
406            return evalMacroExpression(name, it->second);
407        }
408
409        // unexpected
410        return 0;
411    }
412
413    int evalMacroExpression(const std::string & name, const std::string & value) {
414        if (visiting.count(name)) {
415            throw std::runtime_error("Recursive macro: " + name);
416        }
417
418        visiting.insert(name);
419        ExprParser ep(value, macros, visiting);
420        int        v = ep.parse();
421        visiting.erase(name);
422        return v;
423    }
424};
425
426//==============================================================
427// Preprocessor
428//==============================================================
429class Preprocessor {
430  public:
431    explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
432        // Treat empty include path as current directory
433        if (opts_.include_path.empty()) {
434            opts_.include_path = ".";
435        }
436        parseMacroDefinitions(opts_.macros);
437    }
438
439    std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
440        std::unordered_map<std::string, std::string> macros;
441        std::unordered_set<std::string>              predefined;
442        std::unordered_set<std::string>              include_stack;
443        buildMacros(additional_macros, macros, predefined);
444
445        std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
446        return result;
447    }
448
449    std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
450        std::unordered_map<std::string, std::string> macros;
451        std::unordered_set<std::string>              predefined;
452        std::unordered_set<std::string>              include_stack;
453        buildMacros(additional_macros, macros, predefined);
454
455        std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
456        return result;
457    }
458
459    std::string preprocess_includes_file(const std::string & filename) {
460        std::unordered_map<std::string, std::string> macros;
461        std::unordered_set<std::string>              predefined;
462        std::unordered_set<std::string>              include_stack;
463        std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
464        return result;
465    }
466
467    std::string preprocess_includes(const std::string & contents) {
468        std::unordered_map<std::string, std::string> macros;
469        std::unordered_set<std::string>              predefined;
470        std::unordered_set<std::string>              include_stack;
471        std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
472        return result;
473    }
474
475  private:
476    Options                                      opts_;
477    std::unordered_map<std::string, std::string> global_macros;
478
479    enum class DirectiveMode { All, IncludesOnly };
480
481    struct Cond {
482        bool parent_active;
483        bool active;
484        bool taken;
485    };
486
487    //----------------------------------------------------------
488    // Parse macro definitions into global_macros
489    //----------------------------------------------------------
490    void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
491        for (const auto & def : macro_defs) {
492            size_t eq_pos = def.find('=');
493            if (eq_pos != std::string::npos) {
494                // Format: NAME=VALUE
495                std::string name    = trim(def.substr(0, eq_pos));
496                std::string value   = trim(def.substr(eq_pos + 1));
497                global_macros[name] = value;
498            } else {
499                // Format: NAME
500                std::string name    = trim(def);
501                global_macros[name] = "";
502            }
503        }
504    }
505
506    //----------------------------------------------------------
507    // Build combined macro map and predefined set for a preprocessing operation
508    //----------------------------------------------------------
509    void buildMacros(const std::vector<std::string> &               additional_macros,
510                     std::unordered_map<std::string, std::string> & macros,
511                     std::unordered_set<std::string> &              predefined) {
512        macros = global_macros;
513        predefined.clear();
514
515        for (const auto & [name, value] : global_macros) {
516            predefined.insert(name);
517        }
518
519        for (const auto & def : additional_macros) {
520            size_t      eq_pos = def.find('=');
521            std::string name, value;
522            if (eq_pos != std::string::npos) {
523                name  = trim(def.substr(0, eq_pos));
524                value = trim(def.substr(eq_pos + 1));
525            } else {
526                name  = trim(def);
527                value = "";
528            }
529
530            // Add to macros map (will override global if same name)
531            macros[name] = value;
532            predefined.insert(name);
533        }
534    }
535
536    //----------------------------------------------------------
537    // Helpers
538    //----------------------------------------------------------
539    std::string loadFile(const std::string & fname) {
540        std::ifstream f(fname);
541        if (!f.is_open()) {
542            throw std::runtime_error("Could not open file: " + fname);
543        }
544        std::stringstream ss;
545        ss << f.rdbuf();
546        return ss.str();
547    }
548
549    bool condActive(const std::vector<Cond> & cond) const {
550        if (cond.empty()) {
551            return true;
552        }
553        return cond.back().active;
554    }
555
556    //----------------------------------------------------------
557    // Process a file
558    //----------------------------------------------------------
559    std::string processFile(const std::string &                            name,
560                            std::unordered_map<std::string, std::string> & macros,
561                            const std::unordered_set<std::string> &        predefined_macros,
562                            std::unordered_set<std::string> &              include_stack,
563                            DirectiveMode                                  mode) {
564        if (include_stack.count(name)) {
565            throw std::runtime_error("Recursive include: " + name);
566        }
567
568        include_stack.insert(name);
569        std::string shader_code = loadFile(name);
570        std::string out         = processString(shader_code, macros, predefined_macros, include_stack, mode);
571        include_stack.erase(name);
572        return out;
573    }
574
575    std::string processIncludeFile(const std::string &                            fname,
576                                   std::unordered_map<std::string, std::string> & macros,
577                                   const std::unordered_set<std::string> &        predefined_macros,
578                                   std::unordered_set<std::string> &              include_stack,
579                                   DirectiveMode                                  mode) {
580        std::string full_path = opts_.include_path + "/" + fname;
581        return processFile(full_path, macros, predefined_macros, include_stack, mode);
582    }
583
584    //----------------------------------------------------------
585    // Process text
586    //----------------------------------------------------------
587    std::string processString(const std::string &                            shader_code,
588                              std::unordered_map<std::string, std::string> & macros,
589                              const std::unordered_set<std::string> &        predefined_macros,
590                              std::unordered_set<std::string> &              include_stack,
591                              DirectiveMode                                  mode) {
592        std::vector<Cond>  cond;  // Conditional stack for this shader
593        std::stringstream  out;
594        std::istringstream in(shader_code);
595        std::string        line;
596
597        while (std::getline(in, line)) {
598            std::string t = trim(line);
599
600            if (!t.empty() && t[0] == '#') {
601                bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
602                if (mode == DirectiveMode::IncludesOnly && !handled) {
603                    out << line << "\n";
604                }
605            } else {
606                if (mode == DirectiveMode::IncludesOnly) {
607                    out << line << "\n";
608                } else if (condActive(cond)) {
609                    // Expand macros in the line before outputting
610                    std::string expanded = expandMacrosRecursive(line, macros);
611                    out << expanded << "\n";
612                }
613            }
614        }
615
616        if (mode == DirectiveMode::All && !cond.empty()) {
617            throw std::runtime_error("Unclosed #if directive");
618        }
619
620        return out.str();
621    }
622
623    //----------------------------------------------------------
624    // Directive handler
625    //----------------------------------------------------------
626    bool handleDirective(const std::string &                            t,
627                         std::stringstream &                            out,
628                         std::unordered_map<std::string, std::string> & macros,
629                         const std::unordered_set<std::string> &        predefined_macros,
630                         std::vector<Cond> &                            cond,
631                         std::unordered_set<std::string> &              include_stack,
632                         DirectiveMode                                  mode) {
633        // split into tokens
634        std::string        body = t.substr(1);
635        std::istringstream iss(body);
636        std::string        cmd;
637        iss >> cmd;
638
639        if (cmd == "include") {
640            if (mode == DirectiveMode::All && !condActive(cond)) {
641                return true;
642            }
643            std::string file;
644            iss >> file;
645            if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
646                file = file.substr(1, file.size() - 2);
647            }
648            out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
649            return true;
650        }
651
652        if (mode == DirectiveMode::IncludesOnly) {
653            return false;
654        }
655
656        if (cmd == "define") {
657            if (!condActive(cond)) {
658                return true;
659            }
660            std::string name;
661            iss >> name;
662            // Don't override predefined macros from options
663            if (predefined_macros.count(name)) {
664                return true;
665            }
666            std::string value = trim_value(iss);
667            macros[name]      = value;
668            return true;
669        }
670
671        if (cmd == "undef") {
672            if (!condActive(cond)) {
673                return true;
674            }
675            std::string name;
676            iss >> name;
677            // Don't undef predefined macros from options
678            if (predefined_macros.count(name)) {
679                return true;
680            }
681            macros.erase(name);
682            return true;
683        }
684
685        if (cmd == "ifdef") {
686            std::string name;
687            iss >> name;
688            bool p = condActive(cond);
689            bool v = macros.count(name);
690            cond.push_back({ p, p && v, p && v });
691            return true;
692        }
693
694        if (cmd == "ifndef") {
695            std::string name;
696            iss >> name;
697            bool p = condActive(cond);
698            bool v = !macros.count(name);
699            cond.push_back({ p, p && v, p && v });
700            return true;
701        }
702
703        if (cmd == "if") {
704            std::string expr = trim_value(iss);
705            bool        p    = condActive(cond);
706            bool        v    = false;
707            if (p) {
708                std::unordered_set<std::string> visiting;
709                ExprParser                      ep(expr, macros, visiting);
710                v = ep.parse() != 0;
711            }
712            cond.push_back({ p, p && v, p && v });
713            return true;
714        }
715
716        if (cmd == "elif") {
717            std::string expr = trim_value(iss);
718
719            if (cond.empty()) {
720                throw std::runtime_error("#elif without #if");
721            }
722
723            Cond & c = cond.back();
724            if (!c.parent_active) {
725                c.active = false;
726                return true;
727            }
728
729            if (c.taken) {
730                c.active = false;
731                return true;
732            }
733
734            std::unordered_set<std::string> visiting;
735            ExprParser                      ep(expr, macros, visiting);
736            bool                            v = ep.parse() != 0;
737            c.active                          = v;
738            if (v) {
739                c.taken = true;
740            }
741            return true;
742        }
743
744        if (cmd == "else") {
745            if (cond.empty()) {
746                throw std::runtime_error("#else without #if");
747            }
748
749            Cond & c = cond.back();
750            if (!c.parent_active) {
751                c.active = false;
752                return true;
753            }
754            if (c.taken) {
755                c.active = false;
756            } else {
757                c.active = true;
758                c.taken  = true;
759            }
760            return true;
761        }
762
763        if (cmd == "endif") {
764            if (cond.empty()) {
765                throw std::runtime_error("#endif without #if");
766            }
767            cond.pop_back();
768            return true;
769        }
770
771        // Unknown directive
772        throw std::runtime_error("Unknown directive: #" + cmd);
773    }
774};
775
776}  // namespace pre_wgsl
777
778#endif  // PRE_WGSL_HPP