#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <stdexcept>
#include <array>
#include <vector>
#include <map>
#include <memory>
#include <thread>
#include <mutex>
#include <future>
#include <queue>
#include <condition_variable>
#include <cctype>
#include <cerrno>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <cassert>
#include <algorithm>
#include <sys/stat.h>
#include <sys/types.h>
#include <filesystem>

#ifdef _WIN32
    #define NOMINMAX
    #include <windows.h>
    #include <direct.h> // For _mkdir on Windows
#else
    #include <unistd.h>
    #include <sys/wait.h>
    #include <fcntl.h>
    #include <poll.h>
#endif
  32
// NOTE(review): not referenced anywhere in the portion of the file reviewed here — confirm it is still needed.
#define ASYNCIO_CONCURRENCY 64

// Guards shader_fnames when compile tasks append to it (see string_to_spv_func).
std::mutex lock;
// (shader name, .spv output path) pairs recorded for every shader variant.
std::vector<std::pair<std::string, std::string>> shader_fnames;
// Classic "C" locale, imbued into streams created by make_generic_stringstream().
std::locale c_locale("C");

// Shader compiler executable; used as the first element of every compile command.
std::string GLSLC = "glslc";
// Shader source to compile. When empty, string_to_spv() only records shader names and compiles nothing.
std::string input_filepath = "";
// Directory that receives the generated .spv files.
std::string output_dir = "/tmp";
// NOTE(review): presumably the generated header path; it is not used in the code visible here.
std::string target_hpp = "";
// Path of the generated .cpp; the dependency file is written next to it as target_cpp + ".d".
std::string target_cpp = "";
  44
// Data type names for which shader variants are generated. The generator loops over this
// list and upper-cases each name to build the matching "DATA_A_<TYPE>" define.
const std::vector<std::string> type_names = {
    "f32",
    "f16",
    "q4_0",
    "q4_1",
    "q5_0",
    "q5_1",
    "q8_0",
    "q2_k",
    "q3_k",
    "q4_k",
    "q5_k",
    "q6_k",
    "iq1_s",
    "iq1_m",
    "iq2_xxs",
    "iq2_xs",
    "iq2_s",
    "iq3_xxs",
    "iq3_s",
    "iq4_xs",
    "iq4_nl",
    "mxfp4",
    "bf16",
};
  70
// Selects which family of matmul shaders matmul_shaders() generates.
enum MatMulIdType {
    NONE,     // plain "matmul" shaders, no MUL_MAT_ID define
    DEFAULT,  // "matmul_id" shaders, built with MUL_MAT_ID=1
    SUBGROUP, // "matmul_id_subgroup" shaders, built with MUL_MAT_ID=1 and MUL_MAT_ID_USE_SUBGROUPS=1
};
  76
  77namespace {
  78
  79void execute_command(std::vector<std::string>& command, std::string& stdout_str, std::string& stderr_str) {
  80#ifdef _WIN32
  81    HANDLE stdout_read, stdout_write;
  82    HANDLE stderr_read, stderr_write;
  83    SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE };
  84
  85    if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) ||
  86        !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) {
  87        throw std::runtime_error("Failed to create stdout pipe");
  88    }
  89
  90    if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) ||
  91        !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) {
  92        throw std::runtime_error("Failed to create stderr pipe");
  93    }
  94
  95    PROCESS_INFORMATION pi;
  96    STARTUPINFOA si = {};
  97    si.cb = sizeof(STARTUPINFOA);
  98    si.dwFlags = STARTF_USESTDHANDLES;
  99    si.hStdOutput = stdout_write;
 100    si.hStdError = stderr_write;
 101
 102    std::string cmd;
 103    for (const auto& part : command) {
 104        cmd += part + " ";
 105    }
 106
 107    if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) {
 108        throw std::runtime_error("Failed to create process");
 109    }
 110
 111    CloseHandle(stdout_write);
 112    CloseHandle(stderr_write);
 113
 114    std::array<char, 128> buffer;
 115    DWORD bytes_read;
 116
 117    while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
 118        stdout_str.append(buffer.data(), bytes_read);
 119    }
 120
 121    while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) {
 122        stderr_str.append(buffer.data(), bytes_read);
 123    }
 124
 125    CloseHandle(stdout_read);
 126    CloseHandle(stderr_read);
 127    WaitForSingleObject(pi.hProcess, INFINITE);
 128    CloseHandle(pi.hProcess);
 129    CloseHandle(pi.hThread);
 130#else
 131    int stdout_pipe[2];
 132    int stderr_pipe[2];
 133
 134    if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
 135        throw std::runtime_error("Failed to create pipes");
 136    }
 137
 138    pid_t pid = fork();
 139    if (pid < 0) {
 140        throw std::runtime_error("Failed to fork process");
 141    }
 142
 143    std::vector<char*> argv;
 144    for (std::string& part : command) {
 145        argv.push_back(part.data());
 146    }
 147    argv.push_back(nullptr);
 148
 149    if (pid == 0) {
 150        close(stdout_pipe[0]);
 151        close(stderr_pipe[0]);
 152        dup2(stdout_pipe[1], STDOUT_FILENO);
 153        dup2(stderr_pipe[1], STDERR_FILENO);
 154        close(stdout_pipe[1]);
 155        close(stderr_pipe[1]);
 156        execvp(argv[0], argv.data());
 157        _exit(EXIT_FAILURE);
 158    } else {
 159        close(stdout_pipe[1]);
 160        close(stderr_pipe[1]);
 161
 162        std::array<char, 128> buffer;
 163        ssize_t bytes_read;
 164
 165        while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) {
 166            stdout_str.append(buffer.data(), bytes_read);
 167        }
 168
 169        while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) {
 170            stderr_str.append(buffer.data(), bytes_read);
 171        }
 172
 173        close(stdout_pipe[0]);
 174        close(stderr_pipe[0]);
 175        waitpid(pid, nullptr, 0);
 176    }
 177#endif
 178}
 179
 180bool directory_exists(const std::string& path) {
 181    struct stat info;
 182    if (stat(path.c_str(), &info) != 0) {
 183        return false; // Path doesn't exist or can't be accessed
 184    }
 185    return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory
 186}
 187
 188bool create_directory(const std::string& path) {
 189#ifdef _WIN32
 190    return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists
 191#else
 192    return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions
 193#endif
 194}
 195
 196std::string to_uppercase(const std::string& input) {
 197    std::string result = input;
 198    for (char& c : result) {
 199        c = std::toupper(c);
 200    }
 201    return result;
 202}
 203
 204bool string_starts_with(const std::string& str, const std::string& prefix) {
 205    if (prefix.size() > str.size()) {
 206        return false;
 207    }
 208    return std::equal(prefix.begin(), prefix.end(), str.begin());
 209}
 210
 211bool string_ends_with(const std::string& str, const std::string& suffix) {
 212    if (suffix.size() > str.size()) {
 213        return false;
 214    }
 215    return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
 216}
 217
 218bool is_quantized_type(const std::string& type_name) {
 219    return type_name != "f32" && type_name != "f16" && type_name != "bf16";
 220}
 221
 222bool is_legacy_quant(const std::string& type_name) {
 223    return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
 224}
 225
 226bool is_k_quant(const std::string& type_name) {
 227    return string_ends_with(type_name, "_k");
 228}
 229
 230bool is_iq_quant(const std::string& type_name) {
 231    return string_starts_with(type_name, "iq");
 232}
 233
 234static const char path_separator = '/';
 235
 236std::string join_paths(const std::string& path1, const std::string& path2) {
 237    return path1 + path_separator + path2;
 238}
 239
 240std::string basename(const std::string &path) {
 241    return path.substr(path.find_last_of("/\\") + 1);
 242}
 243
 244std::stringstream make_generic_stringstream() {
 245    std::stringstream ss;
 246    ss.imbue(c_locale);
 247    return ss;
 248}
 249
 250std::string read_binary_file(const std::string& path, bool may_not_exist = false) {
 251    FILE* f = fopen(path.c_str(), "rb");
 252    if (!f) {
 253        if (!may_not_exist) {
 254            std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n";
 255        }
 256        return {};
 257    }
 258
 259    fseek(f, 0, SEEK_END);
 260    size_t size = ftell(f);
 261    fseek(f, 0, SEEK_SET);
 262
 263    std::string data(size, '\0');
 264    size_t read_size = fread(data.data(), 1, size, f);
 265    fclose(f);
 266    if (read_size != size) {
 267        std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n";
 268        return {};
 269    }
 270
 271    return data;
 272}
 273
 274void write_binary_file(const std::string& path, const std::string& content) {
 275    FILE* f = fopen(path.c_str(), "wb");
 276    if (!f) {
 277        std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n";
 278        return;
 279    }
 280
 281    size_t write_size = fwrite(content.data(), 1, content.size(), f);
 282    fclose(f);
 283    if (write_size != content.size()) {
 284        std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n";
 285        return;
 286    }
 287}
 288
 289void write_file_if_changed(const std::string& path, const std::string& content) {
 290    std::string existing = read_binary_file(path, true);
 291    if (existing != content) {
 292        write_binary_file(path, content);
 293    }
 294}
 295
 296
// variables to track number of compiles in progress
// Number of compile slots currently handed out; protected by compile_count_mutex.
static uint32_t compile_count = 0;
static std::mutex compile_count_mutex;
// Signalled whenever a slot is released so that acquire_compile_slot() can re-check the limit.
static std::condition_variable compile_count_cond;
// True until the first compile has been launched; string_to_spv() then clears it so that
// only one compile is asked to write the dependency file.
static bool generate_dep_file = true;
 302
 303void decrement_compile_count(uint32_t * count) {
 304    if (count) {
 305        std::lock_guard<std::mutex> guard(compile_count_mutex);
 306        assert(compile_count > 0);
 307        compile_count--;
 308        compile_count_cond.notify_all();
 309    }
 310}
 311
 312using compile_count_guard = std::unique_ptr<uint32_t, decltype(&decrement_compile_count)>;
 313
 314compile_count_guard acquire_compile_slot() {
 315    // wait until fewer than N compiles are in progress.
 316    // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors.
 317    uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency()));
 318    std::unique_lock<std::mutex> guard(compile_count_mutex);
 319    compile_count_cond.wait(guard, [N] { return compile_count < N; });
 320    compile_count++;
 321    return compile_count_guard(&compile_count, &decrement_compile_count);
 322}
 323
 324void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map<std::string, std::string> defines, bool coopmat, bool dep_file, compile_count_guard slot) {
 325    std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2";
 326
 327    #ifdef _WIN32
 328        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""};
 329    #else
 330        std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path};
 331    #endif
 332
 333    // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734
 334    // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344
 335    // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860
 336    if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) {
 337        cmd.push_back("-O");
 338    }
 339
 340    if (dep_file) {
 341        cmd.push_back("-MD");
 342        cmd.push_back("-MF");
 343#ifdef _WIN32
 344        cmd.push_back("\"" + target_cpp + ".d\"");
 345#else
 346        cmd.push_back(target_cpp + ".d");
 347#endif
 348    }
 349
 350    #ifdef GGML_VULKAN_SHADER_DEBUG_INFO
 351        cmd.push_back("-g");
 352    #endif
 353
 354    for (const auto& define : defines) {
 355        cmd.push_back("-D" + define.first + "=" + define.second);
 356    }
 357
 358    std::string command;
 359    for (const auto& part : cmd) {
 360        command += part + " ";
 361    }
 362
 363    std::string stdout_str, stderr_str;
 364    try {
 365        // std::cout << "Executing command: ";
 366        // for (const auto& part : cmd) {
 367        //     std::cout << part << " ";
 368        // }
 369        // std::cout << std::endl;
 370
 371        execute_command(cmd, stdout_str, stderr_str);
 372        if (!stderr_str.empty()) {
 373            std::cerr << "cannot compile " << name << "\n\n";
 374            for (const auto& part : cmd) {
 375                std::cerr << part << " ";
 376            }
 377            std::cerr << "\n\n" << stderr_str << std::endl;
 378            return;
 379        }
 380
 381        if (dep_file) {
 382            // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt
 383            std::string dep = read_binary_file(target_cpp + ".d", true);
 384            if (!dep.empty()) {
 385                size_t pos = dep.find(out_path);
 386                if (pos != std::string::npos) {
 387                    dep.replace(pos, out_path.length(), target_cpp);
 388                }
 389                write_binary_file(target_cpp + ".d", dep);
 390            }
 391        }
 392
 393        std::lock_guard<std::mutex> guard(lock);
 394        shader_fnames.push_back(std::make_pair(name, out_path));
 395    } catch (const std::exception& e) {
 396        std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl;
 397    }
 398}
 399
 400std::map<std::string, std::string> merge_maps(const std::map<std::string, std::string>& a, const std::map<std::string, std::string>& b) {
 401    std::map<std::string, std::string> result = a;
 402    result.insert(b.begin(), b.end());
 403    return result;
 404}
 405
 406static std::vector<std::future<void>> compiles;
 407void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) {
 408    name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32"));
 409    std::string out_path = join_paths(output_dir, name + ".spv");
 410
 411    if (input_filepath == "") {
 412        // No input source to compile, only generate header for all shaders
 413        shader_fnames.push_back(std::pair(name, out_path));
 414        return;
 415    } else if (basename(input_filepath) != source) {
 416        // Only compile shader variants matching the input filename
 417        return;
 418    }
 419
 420    compile_count_guard slot = acquire_compile_slot();
 421    compiles.push_back(std::async(
 422        string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot)));
 423    // Don't write the same dep file from multiple processes
 424    generate_dep_file = false;
 425}
 426
// Registers (through string_to_spv) every matrix-multiplication shader variant for one
// configuration.
//
//   fp16           - when true, defines FLOAT16 and selects 16-bit float types / wider loads.
//   matmul_id_type - picks the shader name prefix ("matmul", "matmul_id" or
//                    "matmul_id_subgroup") and the matching MUL_MAT_ID defines.
//   coopmat        - when true, defines COOPMAT.
//   coopmat2       - when true, variants are built from "mul_mm_cm2.comp" instead of
//                    "mul_mm.comp" and the *_f32 variants of the per-type loop are skipped.
//   f16acc         - when true, the accumulator type is float16_t instead of float.
//
// The flags are also forwarded to string_to_spv(), which turns them into name suffixes.
void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) {
    // Load width and B matrix types used by the "_aligned" variants.
    std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4";
    std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4";
    std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4";

    // Defines shared by every variant generated in this call.
    std::map<std::string, std::string> base_dict;
    std::string shader_name = "matmul";

    if (matmul_id_type == MatMulIdType::DEFAULT) {
        base_dict["MUL_MAT_ID"] = "1";
        shader_name = "matmul_id";
    } else if (matmul_id_type == MatMulIdType::SUBGROUP) {
        base_dict["MUL_MAT_ID"] = "1";
        base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1";
        shader_name = "matmul_id_subgroup";
    }

    if (fp16) {
        base_dict["FLOAT16"] = "1";
    }

    base_dict["ACC_TYPE"     ] = f16acc ? "float16_t" : "float";
    base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2"   : "vec2";
    if (f16acc) {
        base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
    }

    if (coopmat) {
        base_dict["COOPMAT"] = "1";
    }

    const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp";

    // Maps a vector width (1, 2, 4 or 8) and an A type name to the GLSL type used for the
    // FLOAT_TYPE* defines. Throws std::runtime_error for an unsupported width, and for
    // bf16 with width 8 on the coopmat / coopmat2 paths.
    auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string {
        switch (vec) {
        case 1:
            if (t == "bf16") {
                // scalar path promotes to float
                if (!coopmat && !coopmat2) {
                    return "float";
                }
                return "bfloat16_t";
            }
            if (coopmat2 || fp16) {
                return "float16_t";
            }
            return "float";
        case 2:
            if (t == "bf16") {
                // scalar path promotes to float
                if (!coopmat && !coopmat2) {
                    return "vec2";
                }
                return "bf16vec2";
            }
            if (coopmat2 || fp16) {
                return "f16vec2";
            }
            return "vec2";
        case 4:
            if (t == "bf16") {
                // scalar path promotes to float
                if (!coopmat && !coopmat2) {
                    return "vec4";
                }
                return "bf16vec4";
            }
            if (coopmat2 || fp16) {
                return "f16vec4";
            }
            return "vec4";
        case 8:
            if (t == "bf16") {
                // scalar path promotes to float
                if (!coopmat && !coopmat2) {
                    return "mat2x4";
                }
                throw std::runtime_error("bf16 vec8 not supported");
            }
            if (coopmat2 || fp16) {
                return "f16mat2x4";
            }
            return "mat2x4";
        default:
            throw std::runtime_error("invalid vector size");
        }
    };

    const std::map<std::string, std::string> float_type_dict_f16 = {
        {"FLOAT_TYPE",      FLOAT_TYPE(1, "f16")},
        {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")},
        {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")},
        {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")},
    };

    // Shaders with f16 B_TYPE
    string_to_spv(shader_name + "_f32_f16",         source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"},                                                     {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
    string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

    string_to_spv(shader_name + "_f16",             source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"},                                                     {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
    string_to_spv(shader_name + "_f16_aligned",     source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);

    // bf16
    {
        // For aligned matmul loads
        std::string load_vec_a = coopmat2 ? "1" : "4";

        // scalar path promotes to float
        std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32";

        const std::map<std::string, std::string> float_type_dict_bf16 = {
            {"FLOAT_TYPE",      FLOAT_TYPE(1, "bf16")},
            {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")},
            {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")},
        };

        // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader
#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
        if (!(coopmat || coopmat2))
#endif
        {
            string_to_spv(shader_name + "_bf16",         source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"},                             {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}),                   fp16, coopmat, coopmat2, f16acc);
            string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"},  {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
        }
    }

    // One set of variants per A type; bf16 was handled above and is skipped here.
    for (const auto& tname : type_names) {
        // A-side load width for quantized types: 8, 4 or (for every type not listed) 2.
        std::string load_vec_quant = "2";
        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
            load_vec_quant = "8";
        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
            load_vec_quant = "4";

        if (tname == "bf16") {
            continue;
        }

        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
        // NOTE: bf16 never reaches the two expressions below (it is skipped above), so their
        // tname == "bf16" tests are never true.
        // For unaligned, load one at a time for f32/f16, or two at a time for quants
        std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant;
        // For aligned matmul loads
        std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant;

        const std::map<std::string, std::string> float_type_dict = {
            {"FLOAT_TYPE",      FLOAT_TYPE(1, tname)},
            {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)},
            {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)},
            {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)},
        };

        // don't generate f32 variants for coopmat2
        if (!coopmat2) {
            string_to_spv(shader_name + "_" + tname + "_f32",         source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float"},            {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
        }

        // The f16-B variants for f32 and f16 A types were already generated before the loop.
        if (tname != "f16" && tname != "f32") {
            string_to_spv(shader_name + "_" + tname + "_f16",         source_name,  merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
            string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name,  merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a},           {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
        }

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
        // Integer dot mmq performs better with f32 accumulators
        if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) {
            string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
        }
#endif
    }
}
 596
 597void process_shaders() {
 598    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
 599
 600    // matmul
 601    for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
 602        // No coopmats
 603        // fp32
 604        matmul_shaders(false, matmul_id_type, false, false, false);
 605
 606        // fp16, fp32acc and fp16acc
 607        matmul_shaders(true, matmul_id_type, false, false, false);
 608        matmul_shaders(true, matmul_id_type, false, false, true);
 609
 610        if (matmul_id_type != MatMulIdType::DEFAULT) {
 611#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 612            // Coopmat, fp32acc and fp16acc
 613            matmul_shaders(true, matmul_id_type, true, false, false);
 614            matmul_shaders(true, matmul_id_type, true, false, true);
 615#endif
 616
 617#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 618            // Coopmat2, fp32acc and fp16acc
 619            matmul_shaders(true, matmul_id_type, false, true, false);
 620            matmul_shaders(true, matmul_id_type, false, true, true);
 621#endif
 622        }
 623    }
 624
 625    // flash attention
 626    for (const auto& f16acc : {false, true}) {
 627        std::map<std::string, std::string> fa_base_dict = base_dict;
 628        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
 629        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
 630        if (f16acc) {
 631            fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
 632        }
 633
 634        for (const auto& tname : type_names) {
 635            if (tname == "bf16") continue;
 636
 637#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 638            if (tname == "f16") {
 639                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
 640                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
 641            } else {
 642                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 643                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
 644                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
 645            }
 646#endif
 647#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 648            if (tname == "f16") {
 649                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
 650                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
 651            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
 652                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 653                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
 654                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
 655            }
 656#endif
 657            if (tname == "f16") {
 658                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
 659                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
 660            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
 661                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 662                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
 663                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
 664            }
 665        }
 666    }
 667
 668    for (const auto& tname : type_names) {
 669        // mul mat vec
 670        std::string data_a_key = "DATA_A_" + to_uppercase(tname);
 671        std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
 672
 673        string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 674        string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
 675
 676        string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
 677        string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
 678
 679        string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
 680        string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
 681
 682        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
 683        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
 684        string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
 685
 686        // mul mat vec with integer dot product
 687#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 688        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
 689            string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
 690            string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
 691            string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
 692
 693            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
 694            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
 695            string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
 696        }
 697#endif
 698
 699        // Dequant shaders
 700        if (tname != "f16" && tname != "bf16") {
 701            string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
 702        }
 703
 704        shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? "get_rows.comp" : "get_rows_quant.comp";
 705
 706        if (tname == "f16") {
 707            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
 708        } else {
 709            string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
 710        }
 711        string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{"TEMP_TYPE", "FLOAT_TYPE"}, {data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
 712    }
 713
 714    string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}});
 715
 716    string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
 717    string_to_spv("mul_mat_vec_p021_f16_f32",              "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 718    string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}});
 719
 720    // Norms
 721    string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 722    string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 723    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 724    string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 725    string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}}));
 726    string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}}));
 727    string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 728    string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 729
 730    string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 731    string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
 732    string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 733    string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 734    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 735    string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 736    string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
 737    string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
 738    string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
 739    string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 740    string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 741    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 742    string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
 743    string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
 744
 745    string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
 746    string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}});
 747
 748    for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
 749        string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 750        string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
 751        string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 752    }
 753
 754    for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
 755        string_to_spv("set_rows_" + t + "_i32",     "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 756        string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
 757        string_to_spv("set_rows_" + t + "_i64",     "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 758        string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
 759    }
 760
 761    auto get_type_str = [](bool f16) {
 762        return f16 ? "float16_t" : "float";
 763    };
 764    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
 765        std::string s;
 766        s += std::string(src0_f16 ? "_f16" : "_f32");
 767        s += std::string(src1_f16 ? "_f16" : "_f32");
 768        s += std::string(dst_f16 ? "_f16" : "_f32");
 769        return s;
 770    };
 771    for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
 772    for (auto src0_f16 : {false, true}) {
 773    for (auto src1_f16 : {false, true}) {
 774    for (auto dst_f16  : {false, true}) {
 775    for (auto rte      : {false, true}) {
 776        auto source = op == "add_rms" ? std::string("add") : op;
 777        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
 778        auto add_rms = op == "add_rms" ? "1" : "0";
 779        string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
 780    }
 781    }
 782    }
 783    }
 784    }
 785
 786    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 787
 788    string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 789
 790    string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
 791    string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
 792
 793    string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {});
 794
 795    string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
 796    string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
 797
 798    string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
 799    string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
 800
 801    string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 802
 803    string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 804
 805    string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 806    string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 807
 808    string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 809
 810    string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 811
 812    string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 813
 814    string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 815
 816    string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 817
 818    string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 819
 820    string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 821
 822    string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 823    string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 824    string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}});
 825
 826    string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 827
 828    for (auto rte : {false, true}) {
 829        std::string suffix = rte ? "_rte" : "";
 830        string_to_spv("exp_f16" + suffix,        "exp.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 831        string_to_spv("exp_f32" + suffix,        "exp.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}    ,   {"RTE16", rte ? "1" : "0"}});
 832
 833        string_to_spv("log_f16" + suffix,        "log.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 834        string_to_spv("log_f32" + suffix,        "log.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 835    }
 836    string_to_spv("gelu_f16",       "gelu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 837    string_to_spv("gelu_f32",       "gelu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 838    string_to_spv("gelu_erf_f16",   "gelu_erf.comp",    {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 839    string_to_spv("gelu_erf_f32",   "gelu_erf.comp",    {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 840    string_to_spv("gelu_quick_f16", "gelu_quick.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 841    string_to_spv("gelu_quick_f32", "gelu_quick.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 842    string_to_spv("silu_f16",       "silu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 843    string_to_spv("silu_f32",       "silu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 844    string_to_spv("relu_f16",       "relu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 845    string_to_spv("relu_f32",       "relu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 846    string_to_spv("neg_f16",        "neg.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 847    string_to_spv("neg_f32",        "neg.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 848    string_to_spv("tanh_f16",       "tanh.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 849    string_to_spv("tanh_f32",       "tanh.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 850    string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 851    string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 852    string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 853    string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 854    string_to_spv("hardswish_f16",  "hardswish.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 855    string_to_spv("hardswish_f32",  "hardswish.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 856    string_to_spv("abs_f16",        "abs.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 857    string_to_spv("abs_f32",        "abs.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 858    string_to_spv("xielu_f16",      "xielu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 859    string_to_spv("xielu_f32",      "xielu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 860
 861    string_to_spv("tri_f16",        "tri.comp",         {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 862    string_to_spv("tri_f32",        "tri.comp",         {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 863    string_to_spv("diag_f16",       "diag.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 864    string_to_spv("diag_f32",       "diag.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 865
 866    string_to_spv("softplus_f16",   "softplus.comp",    {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 867    string_to_spv("softplus_f32",   "softplus.comp",    {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 868
 869    string_to_spv("add1_f16_f16",   "add1.comp",        {{"A_TYPE", "float16_t"},   {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 870    string_to_spv("add1_f16_f32",   "add1.comp",        {{"A_TYPE", "float16_t"},   {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 871    string_to_spv("add1_f32_f32",   "add1.comp",        {{"A_TYPE", "float"},       {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 872    string_to_spv("arange_f32",     "arange.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 873    string_to_spv("fill_f32",       "fill.comp",        {{"D_TYPE", "float"},       {"FLOAT_TYPE", "float"}});
 874    string_to_spv("step_f16",       "step.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 875    string_to_spv("step_f32",       "step.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 876    string_to_spv("round_f16",      "round.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 877    string_to_spv("round_f32",      "round.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 878    string_to_spv("ceil_f16",       "ceil.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 879    string_to_spv("ceil_f32",       "ceil.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 880    string_to_spv("floor_f16",      "floor.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 881    string_to_spv("floor_f32",      "floor.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 882    string_to_spv("trunc_f16",      "trunc.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 883    string_to_spv("trunc_f32",      "trunc.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 884
 885    for (auto rte : {false, true}) {
 886        std::string suffix = rte ? "_rte" : "";
 887        string_to_spv("geglu_f16" + suffix,      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 888        string_to_spv("geglu_f32" + suffix,      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 889        string_to_spv("reglu_f16" + suffix,      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 890        string_to_spv("reglu_f32" + suffix,      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 891        string_to_spv("swiglu_f16" + suffix,     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 892        string_to_spv("swiglu_f32" + suffix,     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 893        string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 894        string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 895        string_to_spv("geglu_erf_f16" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 896        string_to_spv("geglu_erf_f32" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 897        string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
 898        string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
 899    }
 900
 901    string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 902    string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 903
 904    string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 905
 906    string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 907    string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
 908    string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 909
 910    string_to_spv("soft_max_large1_f32", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 911    string_to_spv("soft_max_large2_f32", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 912    string_to_spv("soft_max_large3_f32", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 913    string_to_spv("soft_max_large1_f32_f16", "soft_max_large1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
 914    string_to_spv("soft_max_large2_f32_f16", "soft_max_large2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
 915    string_to_spv("soft_max_large3_f32_f16", "soft_max_large3.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
 916
 917    string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
 918    string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
 919    string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 920    string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
 921    string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 922
 923    string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
 924    string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
 925    string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 926    string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
 927    string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 928
 929    string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
 930    string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
 931    string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 932    string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}});
 933    string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 934
 935    string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}});
 936    string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}});
 937    string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}});
 938
 939    string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
 940    string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
 941
 942    string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
 943    string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
 944
 945    string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
 946    string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 947    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
 948    string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 949    string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 950    string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 951
 952    string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
 953
 954    for (std::string dim_str : {"", "_3d"}) {
 955        for (bool bda : {false, true}) {
 956            std::string bda_str = bda ? "_bda" : "";
 957            std::string bda_def = bda ? "1" : "0";
 958            string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}}));
 959            string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}}));
 960            string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}}));
 961        }
 962    }
 963
 964    string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 965
 966    string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"},  {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 967
 968    string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 969
 970    string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 971
 972    string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 973
 974    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 975    string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 976
 977    string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
 978
 979    for (auto transpose : {false, true}) {
 980        for (auto unroll : {false, true}) {
 981            for (auto a_f16 : {false, true}) {
 982                std::map<std::string, std::string> defines = {
 983                    {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"},
 984                    {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""},
 985                };
 986                if (transpose) defines["TRANSPOSE"] = "1";
 987                std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d")
 988                    + (a_f16 ? "_f16" : "") + "_f32";
 989                string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines);
 990#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 991                if (unroll) {
 992                    defines["COOPMAT2"] = "1";
 993                    string_to_spv(name, "conv2d_mm.comp", defines, true, false, true);
 994                }
 995#endif
 996            }
 997        }
 998    }
 999
1000    string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
1001    string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
1002    string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
1003    string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
1004
1005    string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
1006
1007    string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
1008
1009    string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
1010    string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
1011
1012    string_to_spv("ssm_scan_f32",          "ssm_scan.comp", {{"A_TYPE", "float"}});
1013    string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
1014
1015    string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
1016
1017    string_to_spv("topk_moe_f32", "topk_moe.comp", {});
1018
1019    for (auto &c : compiles) {
1020        c.wait();
1021    }
1022}
1023
1024void write_output_files() {
1025    std::stringstream hdr = make_generic_stringstream();
1026    std::stringstream src = make_generic_stringstream();
1027
1028    hdr << "#include <cstdint>\n\n";
1029    src << "#include \"" << basename(target_hpp) << "\"\n\n";
1030
1031    std::sort(shader_fnames.begin(), shader_fnames.end());
1032    for (const auto& pair : shader_fnames) {
1033        const std::string& name = pair.first;
1034        #ifdef _WIN32
1035            std::string path = pair.second;
1036            std::replace(path.begin(), path.end(), '/', '\\' );
1037        #else
1038            const std::string& path = pair.second;
1039        #endif
1040
1041        hdr << "extern const uint64_t " << name << "_len;\n";
1042        hdr << "extern const unsigned char " << name << "_data[];\n\n";
1043
1044        if (input_filepath != "") {
1045            std::string data = read_binary_file(path);
1046            if (data.empty()) {
1047                continue;
1048            }
1049
1050            src << "const uint64_t " << name << "_len = " << data.size() << ";\n";
1051            src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex;
1052            auto bytes = reinterpret_cast<const uint8_t*>(data.data());
1053            for (size_t i = 0; i < data.size(); ++i) {
1054                src << "0x" << static_cast<int>(bytes[i]) << ",";
1055                if ((i + 1) % 12 == 0) src << "\n";
1056            }
1057            src << std::dec << "\n};\n\n";
1058        }
1059    }
1060
1061    std::string suffixes[2] = {"_f32", "_f16"};
1062    for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
1063        hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
1064        hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
1065
1066        std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp";
1067        if (basename(input_filepath) != op_file) {
1068            continue;
1069        }
1070        std::stringstream data = make_generic_stringstream();
1071        std::stringstream len  = make_generic_stringstream();
1072        data << "const void * " << op << "_data[2][2][2][2] = ";
1073        len  << "const uint64_t " << op << "_len[2][2][2][2] = ";
1074        for (uint32_t t0 = 0; t0 < 2; ++t0) {
1075            if (t0 == 0) {
1076                data << "{";
1077                len  << "{";
1078            }
1079            for (uint32_t t1 = 0; t1 < 2; ++t1) {
1080                if (t1 == 0) {
1081                    data << "{";
1082                    len  << "{";
1083                }
1084                for (uint32_t t2 = 0; t2 < 2; ++t2) {
1085                    if (t2 == 0) {
1086                        data << "{";
1087                        len  << "{";
1088                    }
1089                    for (uint32_t rte = 0; rte < 2; ++rte) {
1090                        if (rte == 0) {
1091                            data << "{";
1092                            len  << "{";
1093                        }
1094                        data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
1095                        len  << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : "");
1096                        data << "_data,";
1097                        len  << "_len,";
1098                        if (rte == 1) {
1099                            data << "}, ";
1100                            len  << "}, ";
1101                        }
1102                    }
1103                    if (t2 == 1) {
1104                        data << "}, ";
1105                        len  << "}, ";
1106                    }
1107                }
1108                if (t1 == 1) {
1109                    data << "}, ";
1110                    len  << "}, ";
1111                }
1112            }
1113            if (t0 == 1) {
1114                data << "};\n";
1115                len  << "};\n";
1116            }
1117        }
1118        src << data.str();
1119        src << len.str();
1120    }
1121
1122    std::vector<std::string> btypes = {"f16", "f32"};
1123
1124#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
1125    btypes.push_back("q8_1");
1126#endif
1127
1128    for (const std::string& btype : btypes) {
1129    for (const auto& tname : type_names) {
1130        if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
1131            continue;
1132        }
1133        hdr << "extern const void * arr_dmmv_"   << tname << "_" << btype << "_f32_data[3];\n";
1134        hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n";
1135        if (basename(input_filepath) == "mul_mat_vec.comp") {
1136            src << "const void * arr_dmmv_"   << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
1137            src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] =  {mul_mat_vec_" << tname << "_" << btype << "_f32_len,  mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_"  << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
1138        }
1139
1140        if (btype == "f16") {
1141            continue;
1142        }
1143        hdr << "extern const void * arr_dmmv_id_"   << tname << "_" << btype << "_f32_data[3];\n";
1144        hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
1145        if (basename(input_filepath) == "mul_mat_vec.comp") {
1146            src << "const void * arr_dmmv_id_"   << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
1147            src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] =  {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len,  mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_"  << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
1148        }
1149    }
1150    }
1151
1152    if (input_filepath == "") {
1153        write_file_if_changed(target_hpp, hdr.str());
1154    }
1155    if (target_cpp != "") {
1156        write_binary_file(target_cpp, src.str());
1157    }
1158}
1159
1160} // namespace
1161
1162int main(int argc, char** argv) {
1163    std::map<std::string, std::string> args;
1164    for (int i = 1; i < argc; ++i) {
1165        std::string arg = argv[i];
1166        if (arg.rfind("--", 0) == 0) {
1167            if (i + 1 < argc && argv[i + 1][0] != '-') {
1168                args[arg] = argv[i + 1];
1169                ++i;
1170            } else {
1171                args[arg] = "";
1172            }
1173        }
1174    }
1175
1176    if (args.find("--glslc") != args.end()) {
1177        GLSLC = args["--glslc"]; // Path to glslc
1178    }
1179    if (args.find("--source") != args.end()) {
1180        input_filepath = args["--source"]; // The shader source file to compile
1181    }
1182    if (args.find("--output-dir") != args.end()) {
1183        output_dir = args["--output-dir"]; // Directory for containing SPIR-V output
1184    }
1185    if (args.find("--target-hpp") != args.end()) {
1186        target_hpp = args["--target-hpp"]; // Path to generated header file
1187    }
1188    if (args.find("--target-cpp") != args.end()) {
1189        target_cpp = args["--target-cpp"]; // Path to generated cpp file
1190    }
1191
1192    if (!directory_exists(output_dir)) {
1193        if (!create_directory(output_dir)) {
1194            std::cerr << "Error creating output directory: " << output_dir << "\n";
1195            return EXIT_FAILURE;
1196        }
1197    }
1198
1199    process_shaders();
1200
1201    write_output_files();
1202
1203    return EXIT_SUCCESS;
1204}