summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp')
-rw-r--r--llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp778
1 files changed, 778 insertions, 0 deletions
diff --git a/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp b/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp
new file mode 100644
index 0000000..4d43594
--- /dev/null
+++ b/llama.cpp/ggml/src/ggml-webgpu/pre_wgsl.hpp
@@ -0,0 +1,778 @@
+#ifndef PRE_WGSL_HPP
+#define PRE_WGSL_HPP
+
+#include <cctype>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace pre_wgsl {
+
+//==============================================================
+// Options
+//==============================================================
+struct Options {
+ std::string include_path = ".";
+ std::vector<std::string> macros;
+};
+
+//==============================================================
+// Utility: trim
+//==============================================================
+static std::string trim(const std::string & s) {
+ size_t a = 0;
+ while (a < s.size() && std::isspace((unsigned char) s[a])) {
+ a++;
+ }
+ size_t b = s.size();
+ while (b > a && std::isspace((unsigned char) s[b - 1])) {
+ b--;
+ }
+ return s.substr(a, b - a);
+}
+
+static std::string trim_value(std::istream & is) {
+ std::string str;
+ std::getline(is, str);
+ return trim(str);
+}
+
+static bool isIdentChar(char c) {
+ return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
+}
+
+static std::string expandMacrosRecursiveInternal(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting);
+
+static std::string expandMacroValue(const std::string & name,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) {
+ if (visiting.count(name)) {
+ throw std::runtime_error("Recursive macro: " + name);
+ }
+ visiting.insert(name);
+
+ auto it = macros.find(name);
+ if (it == macros.end()) {
+ visiting.erase(name);
+ return name;
+ }
+
+ const std::string & value = it->second;
+ if (value.empty()) {
+ visiting.erase(name);
+ return "";
+ }
+
+ std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
+ visiting.erase(name);
+ return expanded;
+}
+
+static std::string expandMacrosRecursiveInternal(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) {
+ std::string result;
+ result.reserve(line.size());
+
+ size_t i = 0;
+ while (i < line.size()) {
+ if (isIdentChar(line[i])) {
+ size_t start = i;
+ while (i < line.size() && isIdentChar(line[i])) {
+ i++;
+ }
+ std::string token = line.substr(start, i - start);
+
+ auto it = macros.find(token);
+ if (it != macros.end()) {
+ result += expandMacroValue(token, macros, visiting);
+ } else {
+ result += token;
+ }
+ } else {
+ result += line[i];
+ i++;
+ }
+ }
+
+ return result;
+}
+
+static std::string expandMacrosRecursive(const std::string & line,
+ const std::unordered_map<std::string, std::string> & macros) {
+ std::unordered_set<std::string> visiting;
+ return expandMacrosRecursiveInternal(line, macros, visiting);
+}
+
+//==============================================================
+// Tokenizer for expressions in #if/#elif
+//==============================================================
+class ExprLexer {
+ public:
+ enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };
+
+ struct Tok {
+ Kind kind;
+ std::string text;
+ };
+
+ explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}
+
+ Tok next() {
+ skipWS();
+ if (pos >= src.size()) {
+ return { END, "" };
+ }
+
+ char c = src[pos];
+
+ // number
+ if (std::isdigit((unsigned char) c)) {
+ size_t start = pos;
+ while (pos < src.size() && std::isdigit((unsigned char) src[pos])) {
+ pos++;
+ }
+ return { NUMBER, std::string(src.substr(start, pos - start)) };
+ }
+
+ // identifier
+ if (std::isalpha((unsigned char) c) || c == '_') {
+ size_t start = pos;
+ while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) {
+ pos++;
+ }
+ return { IDENT, std::string(src.substr(start, pos - start)) };
+ }
+
+ if (c == '(') {
+ pos++;
+ return { LPAREN, "(" };
+ }
+ if (c == ')') {
+ pos++;
+ return { RPAREN, ")" };
+ }
+
+ // multi-char operators
+ static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
+ for (auto op : two_ops) {
+ if (src.substr(pos, 2) == op) {
+ pos += 2;
+ return { OP, std::string(op) };
+ }
+ }
+
+ // single-char operators
+ if (std::string("+-*/%<>!").find(c) != std::string::npos) {
+ pos++;
+ return { OP, std::string(1, c) };
+ }
+
+ // unexpected
+ pos++;
+ return { END, "" };
+ }
+
+ private:
+ std::string_view src;
+ size_t pos;
+
+ void skipWS() {
+ while (pos < src.size() && std::isspace((unsigned char) src[pos])) {
+ pos++;
+ }
+ }
+};
+
+//==============================================================
+// Expression Parser (recursive descent)
+//==============================================================
+class ExprParser {
+ public:
+ ExprParser(std::string_view expr,
+ const std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & visiting) :
+ lex(expr),
+ macros(macros),
+ visiting(visiting) {
+ advance();
+ }
+
+ int parse() { return parseLogicalOr(); }
+
+ private:
+ ExprLexer lex;
+ ExprLexer::Tok tok;
+ const std::unordered_map<std::string, std::string> & macros;
+ std::unordered_set<std::string> & visiting;
+
+ void advance() { tok = lex.next(); }
+
+ bool acceptOp(const std::string & s) {
+ if (tok.kind == ExprLexer::OP && tok.text == s) {
+ advance();
+ return true;
+ }
+ return false;
+ }
+
+ bool acceptKind(ExprLexer::Kind k) {
+ if (tok.kind == k) {
+ advance();
+ return true;
+ }
+ return false;
+ }
+
+ int parseLogicalOr() {
+ int v = parseLogicalAnd();
+ while (acceptOp("||")) {
+ int rhs = parseLogicalAnd();
+ v = (v || rhs);
+ }
+ return v;
+ }
+
+ int parseLogicalAnd() {
+ int v = parseEquality();
+ while (acceptOp("&&")) {
+ int rhs = parseEquality();
+ v = (v && rhs);
+ }
+ return v;
+ }
+
+ int parseEquality() {
+ int v = parseRelational();
+ for (;;) {
+ if (acceptOp("==")) {
+ int rhs = parseRelational();
+ v = (v == rhs);
+ } else if (acceptOp("!=")) {
+ int rhs = parseRelational();
+ v = (v != rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseRelational() {
+ int v = parseShift();
+ for (;;) {
+ if (acceptOp("<")) {
+ int rhs = parseShift();
+ v = (v < rhs);
+ } else if (acceptOp(">")) {
+ int rhs = parseShift();
+ v = (v > rhs);
+ } else if (acceptOp("<=")) {
+ int rhs = parseShift();
+ v = (v <= rhs);
+ } else if (acceptOp(">=")) {
+ int rhs = parseShift();
+ v = (v >= rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseShift() {
+ int v = parseAdd();
+ for (;;) {
+ if (acceptOp("<<")) {
+ int rhs = parseAdd();
+ v = (v << rhs);
+ } else if (acceptOp(">>")) {
+ int rhs = parseAdd();
+ v = (v >> rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseAdd() {
+ int v = parseMult();
+ for (;;) {
+ if (acceptOp("+")) {
+ int rhs = parseMult();
+ v = (v + rhs);
+ } else if (acceptOp("-")) {
+ int rhs = parseMult();
+ v = (v - rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseMult() {
+ int v = parseUnary();
+ for (;;) {
+ if (acceptOp("*")) {
+ int rhs = parseUnary();
+ v = (v * rhs);
+ } else if (acceptOp("/")) {
+ int rhs = parseUnary();
+ v = (rhs == 0 ? 0 : v / rhs);
+ } else if (acceptOp("%")) {
+ int rhs = parseUnary();
+ v = (rhs == 0 ? 0 : v % rhs);
+ } else {
+ break;
+ }
+ }
+ return v;
+ }
+
+ int parseUnary() {
+ if (acceptOp("!")) {
+ return !parseUnary();
+ }
+ if (acceptOp("-")) {
+ return -parseUnary();
+ }
+ if (acceptOp("+")) {
+ return +parseUnary();
+ }
+ return parsePrimary();
+ }
+
+ int parsePrimary() {
+ // '(' expr ')'
+ if (acceptKind(ExprLexer::LPAREN)) {
+ int v = parse();
+ if (!acceptKind(ExprLexer::RPAREN)) {
+ throw std::runtime_error("missing ')'");
+ }
+ return v;
+ }
+
+ // number
+ if (tok.kind == ExprLexer::NUMBER) {
+ int v = std::stoi(tok.text);
+ advance();
+ return v;
+ }
+
+ // defined(identifier)
+ if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
+ advance();
+ if (acceptKind(ExprLexer::LPAREN)) {
+ if (tok.kind != ExprLexer::IDENT) {
+ throw std::runtime_error("expected identifier in defined()");
+ }
+ std::string name = tok.text;
+ advance();
+ if (!acceptKind(ExprLexer::RPAREN)) {
+ throw std::runtime_error("missing ) in defined()");
+ }
+ return macros.count(name) ? 1 : 0;
+ } else {
+ // defined NAME
+ if (tok.kind != ExprLexer::IDENT) {
+ throw std::runtime_error("expected identifier in defined NAME");
+ }
+ std::string name = tok.text;
+ advance();
+ return macros.count(name) ? 1 : 0;
+ }
+ }
+
+ // identifier -> treat as integer, if defined use its value else 0
+ if (tok.kind == ExprLexer::IDENT) {
+ std::string name = tok.text;
+ advance();
+ auto it = macros.find(name);
+ if (it == macros.end()) {
+ return 0;
+ }
+ if (it->second.empty()) {
+ return 1;
+ }
+ return evalMacroExpression(name, it->second);
+ }
+
+ // unexpected
+ return 0;
+ }
+
+ int evalMacroExpression(const std::string & name, const std::string & value) {
+ if (visiting.count(name)) {
+ throw std::runtime_error("Recursive macro: " + name);
+ }
+
+ visiting.insert(name);
+ ExprParser ep(value, macros, visiting);
+ int v = ep.parse();
+ visiting.erase(name);
+ return v;
+ }
+};
+
+//==============================================================
+// Preprocessor
+//==============================================================
+class Preprocessor {
+ public:
+ explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
+ // Treat empty include path as current directory
+ if (opts_.include_path.empty()) {
+ opts_.include_path = ".";
+ }
+ parseMacroDefinitions(opts_.macros);
+ }
+
+ std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ buildMacros(additional_macros, macros, predefined);
+
+ std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
+ return result;
+ }
+
+ std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ buildMacros(additional_macros, macros, predefined);
+
+ std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
+ return result;
+ }
+
+ std::string preprocess_includes_file(const std::string & filename) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+ return result;
+ }
+
+ std::string preprocess_includes(const std::string & contents) {
+ std::unordered_map<std::string, std::string> macros;
+ std::unordered_set<std::string> predefined;
+ std::unordered_set<std::string> include_stack;
+ std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+ return result;
+ }
+
+ private:
+ Options opts_;
+ std::unordered_map<std::string, std::string> global_macros;
+
+ enum class DirectiveMode { All, IncludesOnly };
+
+ struct Cond {
+ bool parent_active;
+ bool active;
+ bool taken;
+ };
+
+ //----------------------------------------------------------
+ // Parse macro definitions into global_macros
+ //----------------------------------------------------------
+ void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
+ for (const auto & def : macro_defs) {
+ size_t eq_pos = def.find('=');
+ if (eq_pos != std::string::npos) {
+ // Format: NAME=VALUE
+ std::string name = trim(def.substr(0, eq_pos));
+ std::string value = trim(def.substr(eq_pos + 1));
+ global_macros[name] = value;
+ } else {
+ // Format: NAME
+ std::string name = trim(def);
+ global_macros[name] = "";
+ }
+ }
+ }
+
+ //----------------------------------------------------------
+ // Build combined macro map and predefined set for a preprocessing operation
+ //----------------------------------------------------------
+ void buildMacros(const std::vector<std::string> & additional_macros,
+ std::unordered_map<std::string, std::string> & macros,
+ std::unordered_set<std::string> & predefined) {
+ macros = global_macros;
+ predefined.clear();
+
+ for (const auto & [name, value] : global_macros) {
+ predefined.insert(name);
+ }
+
+ for (const auto & def : additional_macros) {
+ size_t eq_pos = def.find('=');
+ std::string name, value;
+ if (eq_pos != std::string::npos) {
+ name = trim(def.substr(0, eq_pos));
+ value = trim(def.substr(eq_pos + 1));
+ } else {
+ name = trim(def);
+ value = "";
+ }
+
+ // Add to macros map (will override global if same name)
+ macros[name] = value;
+ predefined.insert(name);
+ }
+ }
+
+ //----------------------------------------------------------
+ // Helpers
+ //----------------------------------------------------------
+ std::string loadFile(const std::string & fname) {
+ std::ifstream f(fname);
+ if (!f.is_open()) {
+ throw std::runtime_error("Could not open file: " + fname);
+ }
+ std::stringstream ss;
+ ss << f.rdbuf();
+ return ss.str();
+ }
+
+ bool condActive(const std::vector<Cond> & cond) const {
+ if (cond.empty()) {
+ return true;
+ }
+ return cond.back().active;
+ }
+
+ //----------------------------------------------------------
+ // Process a file
+ //----------------------------------------------------------
+ std::string processFile(const std::string & name,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ if (include_stack.count(name)) {
+ throw std::runtime_error("Recursive include: " + name);
+ }
+
+ include_stack.insert(name);
+ std::string shader_code = loadFile(name);
+ std::string out = processString(shader_code, macros, predefined_macros, include_stack, mode);
+ include_stack.erase(name);
+ return out;
+ }
+
+ std::string processIncludeFile(const std::string & fname,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ std::string full_path = opts_.include_path + "/" + fname;
+ return processFile(full_path, macros, predefined_macros, include_stack, mode);
+ }
+
+ //----------------------------------------------------------
+ // Process text
+ //----------------------------------------------------------
+ std::string processString(const std::string & shader_code,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ std::vector<Cond> cond; // Conditional stack for this shader
+ std::stringstream out;
+ std::istringstream in(shader_code);
+ std::string line;
+
+ while (std::getline(in, line)) {
+ std::string t = trim(line);
+
+ if (!t.empty() && t[0] == '#') {
+ bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
+ if (mode == DirectiveMode::IncludesOnly && !handled) {
+ out << line << "\n";
+ }
+ } else {
+ if (mode == DirectiveMode::IncludesOnly) {
+ out << line << "\n";
+ } else if (condActive(cond)) {
+ // Expand macros in the line before outputting
+ std::string expanded = expandMacrosRecursive(line, macros);
+ out << expanded << "\n";
+ }
+ }
+ }
+
+ if (mode == DirectiveMode::All && !cond.empty()) {
+ throw std::runtime_error("Unclosed #if directive");
+ }
+
+ return out.str();
+ }
+
+ //----------------------------------------------------------
+ // Directive handler
+ //----------------------------------------------------------
+ bool handleDirective(const std::string & t,
+ std::stringstream & out,
+ std::unordered_map<std::string, std::string> & macros,
+ const std::unordered_set<std::string> & predefined_macros,
+ std::vector<Cond> & cond,
+ std::unordered_set<std::string> & include_stack,
+ DirectiveMode mode) {
+ // split into tokens
+ std::string body = t.substr(1);
+ std::istringstream iss(body);
+ std::string cmd;
+ iss >> cmd;
+
+ if (cmd == "include") {
+ if (mode == DirectiveMode::All && !condActive(cond)) {
+ return true;
+ }
+ std::string file;
+ iss >> file;
+ if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
+ file = file.substr(1, file.size() - 2);
+ }
+ out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
+ return true;
+ }
+
+ if (mode == DirectiveMode::IncludesOnly) {
+ return false;
+ }
+
+ if (cmd == "define") {
+ if (!condActive(cond)) {
+ return true;
+ }
+ std::string name;
+ iss >> name;
+ // Don't override predefined macros from options
+ if (predefined_macros.count(name)) {
+ return true;
+ }
+ std::string value = trim_value(iss);
+ macros[name] = value;
+ return true;
+ }
+
+ if (cmd == "undef") {
+ if (!condActive(cond)) {
+ return true;
+ }
+ std::string name;
+ iss >> name;
+ // Don't undef predefined macros from options
+ if (predefined_macros.count(name)) {
+ return true;
+ }
+ macros.erase(name);
+ return true;
+ }
+
+ if (cmd == "ifdef") {
+ std::string name;
+ iss >> name;
+ bool p = condActive(cond);
+ bool v = macros.count(name);
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "ifndef") {
+ std::string name;
+ iss >> name;
+ bool p = condActive(cond);
+ bool v = !macros.count(name);
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "if") {
+ std::string expr = trim_value(iss);
+ bool p = condActive(cond);
+ bool v = false;
+ if (p) {
+ std::unordered_set<std::string> visiting;
+ ExprParser ep(expr, macros, visiting);
+ v = ep.parse() != 0;
+ }
+ cond.push_back({ p, p && v, p && v });
+ return true;
+ }
+
+ if (cmd == "elif") {
+ std::string expr = trim_value(iss);
+
+ if (cond.empty()) {
+ throw std::runtime_error("#elif without #if");
+ }
+
+ Cond & c = cond.back();
+ if (!c.parent_active) {
+ c.active = false;
+ return true;
+ }
+
+ if (c.taken) {
+ c.active = false;
+ return true;
+ }
+
+ std::unordered_set<std::string> visiting;
+ ExprParser ep(expr, macros, visiting);
+ bool v = ep.parse() != 0;
+ c.active = v;
+ if (v) {
+ c.taken = true;
+ }
+ return true;
+ }
+
+ if (cmd == "else") {
+ if (cond.empty()) {
+ throw std::runtime_error("#else without #if");
+ }
+
+ Cond & c = cond.back();
+ if (!c.parent_active) {
+ c.active = false;
+ return true;
+ }
+ if (c.taken) {
+ c.active = false;
+ } else {
+ c.active = true;
+ c.taken = true;
+ }
+ return true;
+ }
+
+ if (cmd == "endif") {
+ if (cond.empty()) {
+ throw std::runtime_error("#endif without #if");
+ }
+ cond.pop_back();
+ return true;
+ }
+
+ // Unknown directive
+ throw std::runtime_error("Unknown directive: #" + cmd);
+ }
+};
+
+} // namespace pre_wgsl
+
+#endif // PRE_WGSL_HPP