1import os
  2import re
  3import ast
  4import argparse
  5
  6
  7def extract_block(text, name):
  8    pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)'
  9    match = re.search(pattern, text, re.DOTALL)
 10    if not match:
 11        raise ValueError(f"Missing block: {name}")
 12    return match.group(1).strip()
 13
 14
 15def parse_decls(decls_text):
 16    decls = {}
 17    for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL):
 18        decls[name.strip()] = code.strip()
 19    return decls
 20
 21
 22def replace_repl_placeholders(variant, template_map):
 23    for repl, code in variant["REPLS"].items():
 24        for key, val in template_map.items():
 25            # Match "key" and avoid matching subsequences using by using \b
 26            code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code)
 27        variant["REPLS"][repl] = code
 28    return variant
 29
 30
 31def replace_placeholders(shader_text, replacements):
 32    for key, val in replacements.items():
 33        # Match {{KEY}} literally, where KEY is escaped
 34        pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
 35        shader_text = re.sub(pattern, str(val), shader_text)
 36    return shader_text
 37
 38
 39def expand_includes(shader, input_dir):
 40    """
 41    Replace #include "file" lines in the text with the contents of that file.
 42    Searches for files relative to input_dir.
 43    """
 44    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)
 45
 46    def replacer(match):
 47        fname = match.group(1)
 48        file_path = os.path.join(input_dir, fname)
 49        if not os.path.exists(file_path):
 50            raise FileNotFoundError(f"Included file not found: {file_path}")
 51        with open(file_path, "r", encoding="utf-8") as f:
 52            included_code = f.read()
 53        # Recursively expand includes inside the included file
 54        return expand_includes(included_code, input_dir)
 55
 56    return include_pattern.sub(replacer, shader)
 57
 58
 59def write_shader(shader_name, shader_code, output_dir, outfile):
 60    if output_dir:
 61        wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
 62        with open(wgsl_filename, "w", encoding="utf-8") as f_out:
 63            f_out.write(shader_code)
 64    outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n')
 65
 66
 67def generate_variants(fname, input_dir, output_dir, outfile):
 68    shader_path = os.path.join(input_dir, fname)
 69    shader_base_name = fname.split(".")[0]
 70
 71    with open(shader_path, "r", encoding="utf-8") as f:
 72        text = f.read()
 73
 74    try:
 75        variants = ast.literal_eval(extract_block(text, "VARIANTS"))
 76    except ValueError:
 77        write_shader(shader_base_name, text, output_dir, outfile)
 78    else:
 79        try:
 80            decls_map = parse_decls(extract_block(text, "DECLS"))
 81        except ValueError:
 82            decls_map = {}
 83        try:
 84            templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES"))
 85        except ValueError:
 86            templates_map = {}
 87
 88        for fname in sorted(os.listdir(input_dir)):
 89            if fname.endswith(".tmpl"):
 90                tmpl_path = os.path.join(input_dir, fname)
 91                with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
 92                    decls = f_tmpl.read()
 93                    decls_map.update(parse_decls(decls))
 94
 95        shader_template = extract_block(text, "SHADER")
 96        for variant in variants:
 97            if "DECLS" in variant:
 98                decls = variant["DECLS"]
 99            else:
100                decls = []
101            decls_code = ""
102            for key in decls:
103                if key not in decls_map:
104                    raise ValueError(f"DECLS key '{key}' not found.")
105                decls_code += decls_map[key] + "\n\n"
106            final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template)
107            if "REPLS" in variant:
108                variant = replace_repl_placeholders(variant, templates_map)
109                final_shader = replace_placeholders(final_shader, variant["REPLS"])
110                # second run to expand placeholders in repl_template
111                final_shader = replace_placeholders(final_shader, variant["REPLS"])
112            final_shader = expand_includes(final_shader, input_dir)
113
114            if "SHADER_NAME" in variant:
115                output_name = variant["SHADER_NAME"]
116            elif "SHADER_SUFFIX" in variant:
117                output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
118            elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]:
119                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]])
120            elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]:
121                output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]])
122            elif "REPLS" in variant and "TYPE" in variant["REPLS"]:
123                output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"]
124            else:
125                output_name = shader_base_name
126            write_shader(output_name, final_shader, output_dir, outfile)
127
128
def main():
    """Command-line entry point: embed every ``*.wgsl`` file in --input_dir.

    Creates --output_dir when it is given. Opens --output_file for writing and
    starts it with a banner comment. Then hands each ``.wgsl`` file name, in
    sorted order, to ``generate_variants``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", required=True)
    cli.add_argument("--output_file", required=True)
    cli.add_argument("--output_dir")
    opts = cli.parse_args()

    gen_dir = opts.output_dir
    if gen_dir:
        os.makedirs(gen_dir, exist_ok=True)

    with open(opts.output_file, "w", encoding="utf-8") as header:
        header.write("// Auto-generated shader embedding\n\n")
        # Sorted so shaders are always emitted in a deterministic order.
        shader_files = [
            entry for entry in sorted(os.listdir(opts.input_dir))
            if entry.endswith(".wgsl")
        ]
        for entry in shader_files:
            generate_variants(entry, opts.input_dir, gen_dir, header)
144
145
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()