1import os
2import re
3import ast
4import argparse
5
6
def extract_block(text, name):
    """Return the stripped body of a ``#define(name) ... #end(name)`` block.

    Args:
        text: Full text of a shader template file.
        name: Block name; matched literally (regex metacharacters are escaped).

    Returns:
        The text between the opening and closing markers with surrounding
        whitespace stripped. Only the first matching block is used.

    Raises:
        ValueError: If no such block is present in ``text``.
    """
    # Escape the name so block names containing regex metacharacters
    # ("+", ".", "(" ...) are matched literally rather than as a pattern.
    escaped = re.escape(name)
    pattern = rf'#define\({escaped}\)\s*(.*?)#end\({escaped}\)'
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise ValueError(f"Missing block: {name}")
    return match.group(1).strip()
13
14
def parse_decls(decls_text):
    """Collect every ``#decl(NAME) ... #enddecl(NAME)`` section of the text.

    Returns a dict mapping each stripped NAME to its stripped body. If the
    same NAME occurs more than once, the later section wins.
    """
    section_re = r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)'
    return {
        label.strip(): body.strip()
        for label, body in re.findall(section_re, decls_text, re.DOTALL)
    }
20
21
def replace_repl_placeholders(variant, template_map):
    """Expand REPL_TEMPLATES keys inside a variant's REPLS values.

    Every whole-word occurrence of each key of ``template_map`` inside each
    value of ``variant["REPLS"]`` is replaced by the corresponding template
    value (converted with ``str``). Keys are applied in ``template_map``
    iteration order, so text produced by one key may be rewritten by a later key.

    Args:
        variant: A variant dict containing a ``"REPLS"`` mapping of str values.
        template_map: Mapping of template key -> replacement text.

    Returns:
        A shallow copy of ``variant`` whose ``"REPLS"`` is a new dict holding the
        expanded values. The caller's dicts are left unmodified.
    """
    expanded = {}
    for repl, code in variant["REPLS"].items():
        for key, val in template_map.items():
            replacement = str(val)
            # Match "key" as a whole word only (\b avoids hitting substrings of
            # longer identifiers). The callable replacement inserts the text
            # verbatim, so backslashes in template code are not regex escapes.
            code = re.sub(
                rf'\b{re.escape(str(key))}\b',
                lambda _m, text=replacement: text,
                code,
            )
        expanded[repl] = code
    return {**variant, "REPLS": expanded}
29
30
def replace_placeholders(shader_text, replacements):
    """Substitute ``{{KEY}}`` placeholders in shader text.

    Whitespace inside the braces is tolerated (``{{ KEY }}``). Keys are matched
    literally and values are converted with ``str()`` and inserted verbatim.

    Args:
        shader_text: Text containing ``{{KEY}}`` placeholders.
        replacements: Mapping of placeholder name -> replacement value.

    Returns:
        The text with every listed placeholder replaced. Placeholders whose key
        is not in ``replacements`` are left untouched.
    """
    for key, val in replacements.items():
        replacement = str(val)
        # Braces are escaped explicitly. The callable replacement keeps any
        # backslash in WGSL code from being parsed as a regex escape.
        pattern = r'\{\{\s*' + re.escape(key) + r'\s*\}\}'
        shader_text = re.sub(pattern, lambda _m, text=replacement: text, shader_text)
    return shader_text
37
38
def expand_includes(shader, input_dir):
    """
    Replace ``#include "file"`` lines in the text with the contents of that file.

    Included paths are always resolved relative to ``input_dir`` (not relative
    to the including file), and includes inside included files are expanded
    recursively. A file may be included several times from different places.
    Only a file that directly or indirectly includes itself is rejected.

    Raises:
        FileNotFoundError: If an included file does not exist.
        ValueError: If the includes form a cycle.
    """
    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)

    def expand(text, active):
        # `active` is the chain of files currently being expanded. Checking it
        # turns an include cycle into a clear error instead of a RecursionError.
        def replacer(match):
            file_path = os.path.join(input_dir, match.group(1))
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"Included file not found: {file_path}")
            key = os.path.realpath(file_path)
            if key in active:
                chain = " -> ".join(active + (key,))
                raise ValueError(f"Circular #include detected: {chain}")
            with open(file_path, "r", encoding="utf-8") as f:
                included_code = f.read()
            # Recursively expand includes inside the included file
            return expand(included_code, active + (key,))

        return include_pattern.sub(replacer, text)

    return expand(shader, ())
57
58
def write_shader(shader_name, shader_code, output_dir, outfile):
    """Emit one shader as a C++ raw-string constant, optionally also as a file.

    Args:
        shader_name: Used for the ``wgsl_<name>`` C++ identifier and, when
            ``output_dir`` is set, for the ``<name>.wgsl`` file name.
        shader_code: Final WGSL source text.
        output_dir: Directory for a standalone ``.wgsl`` copy, or a falsy value
            to skip writing it.
        outfile: Text stream receiving the C++ declaration.
    """
    if output_dir:
        wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
        with open(wgsl_filename, "w", encoding="utf-8") as f_out:
            f_out.write(shader_code)

    # A C++ raw string R"(...)" ends at the first `)"`. Keep the plain form when
    # it is safe (the common case). Otherwise pick a delimiter whose closing
    # sequence does not occur in the code, so the literal is never cut short.
    delim = ""
    suffix = 0
    while f'){delim}"' in shader_code:
        suffix += 1
        delim = f"WGSL{suffix}"
    outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')
65
66
def _optional_block(text, name):
    """Return the body of the `name` block in `text`, or None if it is absent."""
    try:
        return extract_block(text, name)
    except ValueError:
        # extract_block signals a missing block with ValueError.
        return None


def _load_tmpl_decls(input_dir):
    """Parse the #decl sections of every ``*.tmpl`` file in `input_dir`.

    Files are processed in sorted name order, so for a duplicated decl name the
    file that sorts last wins.
    """
    tmpl_decls = {}
    for tmpl_name in sorted(os.listdir(input_dir)):
        if not tmpl_name.endswith(".tmpl"):
            continue
        tmpl_path = os.path.join(input_dir, tmpl_name)
        with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
            tmpl_decls.update(parse_decls(f_tmpl.read()))
    return tmpl_decls


def _variant_output_name(shader_base_name, variant):
    """Choose the generated shader's name for one variant.

    Precedence: explicit SHADER_NAME, then SHADER_SUFFIX, then names derived
    from REPLS entries (SRC0_TYPE/SRC1_TYPE, SRC_TYPE/DST_TYPE, TYPE), and
    finally the template's base name.
    """
    if "SHADER_NAME" in variant:
        return variant["SHADER_NAME"]
    if "SHADER_SUFFIX" in variant:
        return f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
    repls = variant.get("REPLS", {})
    if "SRC0_TYPE" in repls and "SRC1_TYPE" in repls:
        return f"{shader_base_name}_" + "_".join([repls["SRC0_TYPE"], repls["SRC1_TYPE"]])
    if "SRC_TYPE" in repls and "DST_TYPE" in repls:
        return f"{shader_base_name}_" + "_".join([repls["SRC_TYPE"], repls["DST_TYPE"]])
    if "TYPE" in repls:
        return f"{shader_base_name}_" + repls["TYPE"]
    return shader_base_name


def generate_variants(fname, input_dir, output_dir, outfile):
    """Generate and emit every shader variant defined by one ``.wgsl`` file.

    A file without a ``#define(VARIANTS)`` block is embedded verbatim under its
    base name. Otherwise each entry of the VARIANTS list (a Python literal)
    produces one shader from the SHADER block:

    - ``DECLS`` entries are looked up in the file's DECLS block and in every
      ``*.tmpl`` file of ``input_dir``.
    - ``REPLS`` entries fill ``{{KEY}}`` placeholders, after REPL_TEMPLATES
      expansion.
    - ``#include`` lines are expanded last.

    Args:
        fname: File name of the shader template, relative to ``input_dir``.
        input_dir: Directory containing the template, .tmpl and include files.
        output_dir: Optional directory for standalone ``.wgsl`` copies.
        outfile: Text stream receiving the generated C++ declarations.

    Raises:
        ValueError: If a variant names an unknown DECLS key, the SHADER block is
            missing, or a VARIANTS/REPL_TEMPLATES block is a malformed literal.
        SyntaxError: If a VARIANTS/REPL_TEMPLATES block is not valid syntax.
    """
    shader_path = os.path.join(input_dir, fname)
    shader_base_name = fname.split(".")[0]

    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()

    variants_src = _optional_block(text, "VARIANTS")
    if variants_src is None:
        # Plain shader without variants: embed it unchanged.
        write_shader(shader_base_name, text, output_dir, outfile)
        return
    # Parsed outside the "block missing" check: a malformed VARIANTS literal
    # must fail loudly instead of silently embedding the raw template.
    variants = ast.literal_eval(variants_src)

    decls_src = _optional_block(text, "DECLS")
    decls_map = parse_decls(decls_src) if decls_src is not None else {}
    templates_src = _optional_block(text, "REPL_TEMPLATES")
    templates_map = ast.literal_eval(templates_src) if templates_src is not None else {}
    # Shared .tmpl declarations override same-named local DECLS entries.
    decls_map.update(_load_tmpl_decls(input_dir))

    shader_template = extract_block(text, "SHADER")
    for variant in variants:
        decl_parts = []
        for key in variant.get("DECLS", []):
            if key not in decls_map:
                raise ValueError(f"DECLS key '{key}' not found.")
            decl_parts.append(decls_map[key] + "\n\n")
        decls_code = "".join(decl_parts)
        # Callable replacement: decl bodies are inserted verbatim, so a
        # backslash in WGSL text is not parsed as a regex escape.
        final_shader = re.sub(r'\bDECLS\b', lambda _m, code=decls_code: code, shader_template)
        if "REPLS" in variant:
            variant = replace_repl_placeholders(variant, templates_map)
            final_shader = replace_placeholders(final_shader, variant["REPLS"])
            # Second pass expands placeholders introduced by REPL_TEMPLATES text.
            final_shader = replace_placeholders(final_shader, variant["REPLS"])
        final_shader = expand_includes(final_shader, input_dir)

        output_name = _variant_output_name(shader_base_name, variant)
        write_shader(output_name, final_shader, output_dir, outfile)
127
128
def main():
    """Command-line entry point: embed every ``*.wgsl`` shader of --input_dir.

    Writes a C++ source file (--output_file) containing one raw-string constant
    per generated shader, and optionally standalone ``.wgsl`` copies into
    --output_dir (created if needed).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", required=True)
    cli.add_argument("--output_file", required=True)
    cli.add_argument("--output_dir")
    opts = cli.parse_args()

    out_dir = opts.output_dir
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(opts.output_file, "w", encoding="utf-8") as out:
        out.write("// Auto-generated shader embedding\n\n")
        shader_files = (
            entry for entry in sorted(os.listdir(opts.input_dir))
            if entry.endswith(".wgsl")
        )
        for shader_file in shader_files:
            generate_variants(shader_file, opts.input_dir, out_dir, out)
144
145
# Run the generator only when executed as a script, never on import.
if "__main__" == __name__:
    main()