summaryrefslogtreecommitdiff
path: root/llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
blob: d61df5bb9e544253be94a6399aaba03c2cd2f9b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import re
import ast
import argparse


def extract_block(text, name):
    """
    Return the body of a `#define(name) ... #end(name)` block.

    Args:
        text: Full source text to search.
        name: Block name, matched literally between the parentheses.

    Returns:
        The text between the two markers with surrounding whitespace stripped.

    Raises:
        ValueError: If no block with that name is present in `text`.
    """
    # Escape the name so regex metacharacters in it (".", "+", ...) are
    # matched literally instead of changing which markers are accepted.
    marker = re.escape(name)
    pattern = rf'#define\({marker}\)\s*(.*?)#end\({marker}\)'
    # DOTALL lets the lazy body group span multiple lines.
    match = re.search(pattern, text, re.DOTALL)
    if not match:
        raise ValueError(f"Missing block: {name}")
    return match.group(1).strip()


def parse_decls(decls_text):
    """
    Collect every `#decl(NAME) ... #enddecl(NAME)` section of the text.

    Returns a dict mapping each stripped NAME to its stripped body. When the
    same name occurs more than once, the last occurrence wins.
    """
    # The backreference \1 forces the closing marker to repeat the opening name.
    decl_re = re.compile(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', re.DOTALL)
    return {
        found.group(1).strip(): found.group(2).strip()
        for found in decl_re.finditer(decls_text)
    }


def replace_repl_placeholders(variant, template_map):
    """
    Substitute template keys inside every value of `variant["REPLS"]`.

    Each key of `template_map` is matched as a whole word (using `\\b`
    boundaries) and replaced by the string form of its value. Keys are applied
    in the map's iteration order, so text inserted by an earlier key can be
    matched again by a later one.

    Args:
        variant: Variant dict; its "REPLS" entry is updated in place.
        template_map: Mapping of placeholder word -> replacement value.

    Returns:
        The same `variant` object, for convenience.

    Raises:
        KeyError: If `variant` has no "REPLS" entry.
    """
    repls = variant["REPLS"]
    # Build each whole-word pattern once instead of once per REPLS entry.
    substitutions = [
        (re.compile(rf'\b{re.escape(str(key))}\b'), str(val))
        for key, val in template_map.items()
    ]
    for repl, code in repls.items():
        for pattern, replacement in substitutions:
            # A callable replacement inserts the text verbatim; a plain string
            # would have its backslashes interpreted as escapes/group refs.
            code = pattern.sub(lambda _m, _r=replacement: _r, code)
        # Re-assigning an existing key does not resize the dict, so this is
        # safe while iterating over items().
        repls[repl] = code
    return variant


def replace_placeholders(shader_text, replacements):
    """
    Replace every `{{KEY}}` placeholder in the shader text.

    Whitespace between the braces and the key is tolerated (`{{ KEY }}`).
    Placeholders whose key is not in `replacements` are left untouched.

    Args:
        shader_text: Text containing `{{KEY}}` placeholders.
        replacements: Mapping of key -> value; values are converted with str().

    Returns:
        The text with all known placeholders substituted.
    """
    for key, val in replacements.items():
        # Match {{KEY}} literally, where KEY is escaped
        pattern = r'\{\{\s*' + re.escape(key) + r'\s*\}\}'
        replacement = str(val)
        # Use a callable so the replacement is inserted verbatim; a plain
        # string would have backslashes processed as regex escapes.
        shader_text = re.sub(pattern, lambda _m, _r=replacement: _r, shader_text)
    return shader_text


def expand_includes(shader, input_dir, _include_chain=()):
    """
    Replace #include "file" lines in the text with the contents of that file.
    Searches for files relative to input_dir.

    Included files are expanded recursively. The same file may be included
    several times from different places, but a file that (directly or
    indirectly) includes itself is rejected instead of recursing forever.

    Args:
        shader: Shader text that may contain #include lines.
        input_dir: Directory the included file names are resolved against.
        _include_chain: Internal. Tuple of the files currently being expanded,
            outermost first; callers should not pass it.

    Returns:
        The shader text with every #include line replaced.

    Raises:
        FileNotFoundError: If an included file does not exist.
        ValueError: If the includes form a cycle.
    """
    include_pattern = re.compile(r'^\s*#include\s+"([^"]+)"\s*$', re.MULTILINE)

    def replacer(match):
        fname = match.group(1)
        file_path = os.path.join(input_dir, fname)
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Included file not found: {file_path}")
        # Normalise so different spellings of one path compare equal.
        include_id = os.path.normcase(os.path.abspath(file_path))
        if include_id in _include_chain:
            cycle = " -> ".join(_include_chain + (include_id,))
            raise ValueError(f"Circular #include detected: {cycle}")
        with open(file_path, "r", encoding="utf-8") as f:
            included_code = f.read()
        # Recursively expand includes inside the included file, remembering
        # only the active chain so repeated (non-cyclic) includes still work.
        return expand_includes(included_code, input_dir, _include_chain + (include_id,))

    return include_pattern.sub(replacer, shader)


def _raw_string_delimiter(shader_code):
    """Return a C++ raw-string delimiter whose closing sequence is not in the code."""
    delimiter = ""
    counter = 0
    # A raw string R"d(...)d" ends at the first `)d"`, so keep trying
    # delimiters until that closing sequence does not occur in the code.
    while f'){delimiter}"' in shader_code:
        delimiter = f"wgsl{counter}" if counter else "wgsl"
        counter += 1
    return delimiter


def write_shader(shader_name, shader_code, output_dir, outfile):
    """
    Emit one finished shader.

    Args:
        shader_name: Name used for the `.wgsl` file and the `wgsl_<name>`
            C identifier.
        shader_code: Final shader source text.
        output_dir: If truthy, the shader is also written to
            `<output_dir>/<shader_name>.wgsl`.
        outfile: Writable text stream receiving the C++ definition
            `const char* wgsl_<name> = R"(...)";`.
    """
    if output_dir:
        wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl")
        with open(wgsl_filename, "w", encoding="utf-8") as f_out:
            f_out.write(shader_code)
    # The delimiter is empty (original output format) unless the code itself
    # contains `)"`, which would otherwise terminate the raw string early.
    delim = _raw_string_delimiter(shader_code)
    outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n')


def generate_variants(fname, input_dir, output_dir, outfile):
    """
    Emit every variant of one shader source file.

    If the file has no `#define(VARIANTS)` block it is written out unchanged
    under its base name. Otherwise, for each entry of the VARIANTS list:
      1. the declarations named in its "DECLS" list are concatenated and
         substituted for the word DECLS in the `#define(SHADER)` template;
      2. its "REPLS" values have REPL_TEMPLATES words expanded and are then
         substituted for `{{KEY}}` placeholders (two passes, so placeholders
         introduced by a replacement are expanded too);
      3. `#include "file"` lines are expanded;
      4. the result is passed to write_shader().

    Declarations come from the file's own `#define(DECLS)` block plus every
    `*.tmpl` file in `input_dir` (later sources override earlier ones).

    Args:
        fname: Shader file name inside `input_dir`.
        input_dir: Directory holding shader sources, templates and includes.
        output_dir: Optional directory for standalone `.wgsl` output.
        outfile: Writable text stream for the embedded C++ strings.

    Raises:
        ValueError: If a variant names an unknown DECLS key, if the SHADER
            block is missing, or if a VARIANTS / REPL_TEMPLATES block is
            present but is not a valid Python literal.
    """
    shader_path = os.path.join(input_dir, fname)
    # Base name is everything before the FIRST dot of the file name.
    shader_base_name = fname.split(".")[0]

    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()

    # Only a *missing* VARIANTS block means "plain shader". Parsing happens
    # outside the try so a malformed block is reported instead of being
    # silently treated as absent (literal_eval raises ValueError as well).
    try:
        variants_src = extract_block(text, "VARIANTS")
    except ValueError:
        write_shader(shader_base_name, text, output_dir, outfile)
        return
    variants = ast.literal_eval(variants_src)

    try:
        decls_map = parse_decls(extract_block(text, "DECLS"))
    except ValueError:
        decls_map = {}

    try:
        templates_src = extract_block(text, "REPL_TEMPLATES")
    except ValueError:
        templates_map = {}
    else:
        templates_map = ast.literal_eval(templates_src)

    # Shared declarations from template files; a distinct loop name avoids
    # shadowing the `fname` parameter.
    for tmpl_name in sorted(os.listdir(input_dir)):
        if tmpl_name.endswith(".tmpl"):
            tmpl_path = os.path.join(input_dir, tmpl_name)
            with open(tmpl_path, "r", encoding="utf-8") as f_tmpl:
                decls_map.update(parse_decls(f_tmpl.read()))

    shader_template = extract_block(text, "SHADER")
    for variant in variants:
        decl_chunks = []
        for key in variant.get("DECLS", []):
            if key not in decls_map:
                raise ValueError(f"DECLS key '{key}' not found.")
            decl_chunks.append(decls_map[key] + "\n\n")
        decls_code = "".join(decl_chunks)
        # Callable replacement: insert the declarations verbatim so any
        # backslash in them is not interpreted as a regex escape.
        final_shader = re.sub(r'\bDECLS\b', lambda _m: decls_code, shader_template)
        if "REPLS" in variant:
            variant = replace_repl_placeholders(variant, templates_map)
            final_shader = replace_placeholders(final_shader, variant["REPLS"])
            # second run to expand placeholders in repl_template
            final_shader = replace_placeholders(final_shader, variant["REPLS"])
        final_shader = expand_includes(final_shader, input_dir)

        # Output name: explicit name, then explicit suffix, then a suffix
        # derived from the type replacements, else the bare base name.
        repls = variant.get("REPLS", {})
        if "SHADER_NAME" in variant:
            output_name = variant["SHADER_NAME"]
        elif "SHADER_SUFFIX" in variant:
            output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"]
        elif "SRC0_TYPE" in repls and "SRC1_TYPE" in repls:
            output_name = f"{shader_base_name}_" + "_".join([repls["SRC0_TYPE"], repls["SRC1_TYPE"]])
        elif "SRC_TYPE" in repls and "DST_TYPE" in repls:
            output_name = f"{shader_base_name}_" + "_".join([repls["SRC_TYPE"], repls["DST_TYPE"]])
        elif "TYPE" in repls:
            output_name = f"{shader_base_name}_" + repls["TYPE"]
        else:
            output_name = shader_base_name
        write_shader(output_name, final_shader, output_dir, outfile)


def main():
    """
    Command-line entry point.

    Reads --input_dir, --output_file and the optional --output_dir, writes the
    header comment to the output file and then processes every `.wgsl` file of
    the input directory in sorted order.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", required=True)
    cli.add_argument("--output_file", required=True)
    cli.add_argument("--output_dir")
    opts = cli.parse_args()

    if opts.output_dir:
        os.makedirs(opts.output_dir, exist_ok=True)

    with open(opts.output_file, "w", encoding="utf-8") as embedded:
        embedded.write("// Auto-generated shader embedding\n\n")
        # Sorted so the generated file has a stable shader order.
        wgsl_sources = (
            entry
            for entry in sorted(os.listdir(opts.input_dir))
            if entry.endswith(".wgsl")
        )
        for entry in wgsl_sources:
            generate_variants(entry, opts.input_dir, opts.output_dir, embedded)


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()