#include "ggml-vulkan.h"
#include <vulkan/vulkan_core.h>
#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
#include <chrono>
#include "ggml-cpu.h"
#endif

// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
// to avoid conflicts with applications or other libraries that might use it.
#if VK_HEADER_VERSION >= 301
namespace vk::detail { class DispatchLoaderDynamic; }
using vk::detail::DispatchLoaderDynamic;
#else
namespace vk { class DispatchLoaderDynamic; }
using vk::DispatchLoaderDynamic;
#endif
DispatchLoaderDynamic & ggml_vk_default_dispatcher();
#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()

#include <vulkan/vulkan.hpp>

#include <algorithm>
#include <array>
#include <atomic>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <tuple>
#include <vector>
#include <sstream>
#include <string>
#include <utility>
#include <memory>
#include <limits>
#include <map>
#include <set>
#include <unordered_map>
#include <mutex>
#include <future>
#include <thread>

#if defined(_MSC_VER)
# define NOMINMAX 1
# include <windows.h>
# define YIELD() YieldProcessor()
#elif defined(__clang__) || defined(__GNUC__)
# if defined(__x86_64__) || defined(__i386__)
#  include <immintrin.h>
#  define YIELD() _mm_pause()
# elif defined(__arm__) || defined(__aarch64__)
#  if defined(__clang__)
#   include <arm_acle.h>
#   define YIELD() __yield()
#  else
#   define YIELD() asm volatile("yield")
#  endif
# endif
#endif

#if !defined(YIELD)
#define YIELD()
#endif
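// Spin-wait sketch showing how YIELD() is intended to be used while polling an
// atomic flag (illustrative only; the actual call sites appear later in this file):
//   while (!pipeline->compiled.load(std::memory_order_relaxed)) { YIELD(); }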

#include "ggml-impl.h"
#include "ggml-backend-impl.h"

#include "ggml-vulkan-shaders.hpp"

// remove this once it's more widely available in the SDK
#if !defined(VK_KHR_shader_bfloat16)

#define VK_KHR_shader_bfloat16 1
#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION                          1
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME                        "VK_KHR_shader_bfloat16"
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
#define VK_COMPONENT_TYPE_BFLOAT16_KHR                               ((VkComponentTypeKHR)1000141000)

typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
    VkStructureType                       sType;
    void*                                 pNext;
    VkBool32                              shaderBFloat16Type;
    VkBool32                              shaderBFloat16DotProduct;
    VkBool32                              shaderBFloat16CooperativeMatrix;
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
#endif

#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
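// e.g. ROUNDUP_POW2(13, 8) == 16 and CEIL_DIV(13, 8) == 2; N must be a power of two
// for ROUNDUP_POW2. Note that is_pow2() returns false for x <= 1.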

#define VK_VENDOR_ID_AMD 0x1002
#define VK_VENDOR_ID_APPLE 0x106b
#define VK_VENDOR_ID_INTEL 0x8086
#define VK_VENDOR_ID_NVIDIA 0x10de

#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256

#define GGML_VK_MAX_NODES 8192

#define VK_CHECK(err, msg)                                          \
    do {                                                            \
        vk::Result err_ = (err);                                    \
        if (err_ != vk::Result::eSuccess) {                         \
            fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n",  \
                #err, to_string(err_).c_str(), __FILE__, __LINE__); \
            exit(1);                                                \
        }                                                           \
    } while (0)
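// Typical usage (a sketch; real call sites appear throughout this file):
//   VK_CHECK(device.waitForFences({ fence }, true, UINT64_MAX), "wait_for_fences");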

#ifdef GGML_VULKAN_DEBUG
#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
#else
#define VK_LOG_DEBUG(msg) ((void) 0)
#endif // GGML_VULKAN_DEBUG

struct ggml_backend_vk_context;

#define MAX_PARAMETER_COUNT 12
// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT:
// N fused adds bind N+1 sources plus the destination and the rms partials buffer,
// i.e. N+3 descriptors.
#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)

typedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;

struct vk_pipeline_struct {
    std::string name;
    vk::ShaderModule shader_module;
    vk::PipelineLayout layout;
    vk::Pipeline pipeline;
    uint32_t push_constant_size;
    uint32_t parameter_count;
    std::array<uint32_t, 3> wg_denoms;
    uint32_t align;
    // true if fields have been set by ggml_vk_create_pipeline
    bool initialized {};
    // set to true to request that the pipeline be compiled
    std::atomic<bool> needed {};
    // set to true when the shader has been compiled
    std::atomic<bool> compiled {};
    // number of registers used, extracted from pipeline executable properties
    uint32_t register_count {};

#if defined(VK_EXT_shader_64bit_indexing)
    bool is_64b_indexing {};
#endif
    // linked list of pipelines for multiple compilation variants.
    // currently only used to compile a 64-bit indexing variant.
    vk_pipeline next;
};

typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;

static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);

struct vk_matmul_pipeline_struct {
    vk_pipeline l, m, s;
    vk_pipeline a_l, a_m, a_s;
    // Returns true when all unaligned pipelines are null.
    // We only check the unaligned variants, since at least one unaligned pipeline
    // must exist while the aligned pipelines are optional.
    bool is_empty() const {
        return l == nullptr && m == nullptr && s == nullptr;
    }
};
typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;

struct vk_matmul_pipeline2 {
    vk_matmul_pipeline2() {
        f16acc = std::make_shared<vk_matmul_pipeline_struct>();
        f32acc = std::make_shared<vk_matmul_pipeline_struct>();
    }
    vk_matmul_pipeline f32acc;
    vk_matmul_pipeline f16acc;
};

struct vk_device_struct;
typedef std::shared_ptr<vk_device_struct> vk_device;
typedef std::weak_ptr<vk_device_struct> vk_device_ref;

struct vk_buffer_struct;
typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;

struct ggml_backend_vk_buffer_type_context {
    std::string name;
    vk_device device;
};

struct vk_queue;

// Stores command pool/buffers. There's an instance of this
// for each (context,queue) pair and for each (device,queue) pair.
struct vk_command_pool {
    void init(vk_device& device, vk_queue *q_);
    void destroy(vk::Device& device);

    vk::CommandPool pool;
    uint32_t cmd_buffer_idx;
    std::vector<vk::CommandBuffer> cmd_buffers;

    vk_queue *q;
};

// Prevent simultaneous submissions to the same queue.
// This could be per vk_queue if we stopped having two vk_queue structures
// sharing the same vk::Queue.
static std::mutex queue_mutex;

struct vk_queue {
    uint32_t queue_family_index;
    vk::Queue queue;

    vk_command_pool cmd_pool;

    vk::PipelineStageFlags stage_flags;

    bool transfer_only;

    // copy everything except the cmd_pool
    void copyFrom(vk_queue &other) {
        queue_family_index = other.queue_family_index;
        queue = other.queue;
        stage_flags = other.stage_flags;
        transfer_only = other.transfer_only;
    }
};

static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_vk_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_vk_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_vk_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_vk_buffer_type_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_vk_buffer_type_get_alloc_size,
    /* .is_host          = */ NULL,
};

class vk_memory_logger;
class vk_perf_logger;
static void ggml_vk_destroy_buffer(vk_buffer& buf);
static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);

static constexpr uint32_t mul_mat_vec_max_cols = 8;
static constexpr uint32_t p021_max_gqa_ratio = 8;

enum vk_device_architecture {
    OTHER,
    AMD_GCN,
    AMD_RDNA1,
    AMD_RDNA2,
    AMD_RDNA3,
    INTEL_XE2,
    NVIDIA_PRE_TURING,
    NVIDIA_TURING,
};

static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
    vk::PhysicalDeviceProperties props = device.getProperties();

    if (props.vendorID == VK_VENDOR_ID_AMD) {
        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();

        bool amd_shader_core_properties = false;
        bool integer_dot_product = false;
        bool subgroup_size_control = false;

        for (const auto& properties : ext_props) {
            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
                amd_shader_core_properties = true;
            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
                integer_dot_product = true;
            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                subgroup_size_control = true;
            }
        }

        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
            return vk_device_architecture::OTHER;
        }

        vk::PhysicalDeviceProperties2 props2;
        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;

        props2.pNext = &shader_core_props_amd;
        shader_core_props_amd.pNext = &integer_dot_props;
        integer_dot_props.pNext = &subgroup_size_control_props;

        device.getProperties2(&props2);

        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
            return vk_device_architecture::AMD_GCN;
        }
        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
            // RDNA
            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
                return vk_device_architecture::AMD_RDNA1;
            }
            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
                return vk_device_architecture::AMD_RDNA3;
            }
            return vk_device_architecture::AMD_RDNA2;
        }
    } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();

        bool subgroup_size_control = false;

        for (const auto& properties : ext_props) {
            if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                subgroup_size_control = true;
            }
        }

        if (!subgroup_size_control) {
            return vk_device_architecture::OTHER;
        }

        vk::PhysicalDeviceProperties2 props2;
        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;

        props2.pNext = &subgroup_size_control_props;
        device.getProperties2(&props2);

        if (subgroup_size_control_props.minSubgroupSize == 16) {
            // The Xe2 architecture uses SIMD16 while previous Xe and Gen architectures use SIMD8.
            // The minimum subgroup size matches the SIMD width, so we can distinguish the architecture by checking this value.
            // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
            // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
            return vk_device_architecture::INTEL_XE2;
        }
    } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();

        bool cooperative_matrix = false;
        bool sm_builtins = false;

        // Detect "pre-Turing" based on lack of coopmat support.
        for (const auto& properties : ext_props) {
            if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
                cooperative_matrix = true;
            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
                sm_builtins = true;
            }
        }

        if (!cooperative_matrix) {
            return vk_device_architecture::NVIDIA_PRE_TURING;
        }

        if (sm_builtins) {
            vk::PhysicalDeviceProperties2 props2;
            vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;

            props2.pNext = &sm_props;

            device.getProperties2(&props2);

            // Turing has 32 warps per SM; later architectures have 48
            if (sm_props.shaderWarpsPerSM == 32) {
                return vk_device_architecture::NVIDIA_TURING;
            }
        }
    }
    return vk_device_architecture::OTHER;
}

enum vk_conv_shapes {
    CONV_SHAPE_128x128,
    CONV_SHAPE_64x32,
    CONV_SHAPE_32x256,
    CONV_SHAPE_COUNT,
};

struct vk_conv_block_size {
    uint32_t K;
    uint32_t NPQ;
    uint32_t CRS;
};

vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
    // K   NPQ  CRS
    { 128, 128, 16 }, // CONV_SHAPE_128x128
    {  64,  32, 32 }, // CONV_SHAPE_64x32
    {  32, 256, 16 }, // CONV_SHAPE_32x256
};

enum dmmv_wg_sizes {
    DMMV_WG_SIZE_SUBGROUP,
    DMMV_WG_SIZE_LARGE,
    DMMV_WG_SIZE_COUNT,
};

enum FaCodePath {
    FA_SCALAR,
    FA_COOPMAT1,
    FA_COOPMAT2,
};

struct vk_fa_pipeline_state {
    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}

    uint32_t HSK, HSV;
    bool small_rows, small_cache;
    FaCodePath path;
    bool aligned;
    bool f32acc;
    uint32_t flags;

    bool operator<(const vk_fa_pipeline_state &b) const {
        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
    }
};

struct vk_conv2d_pipeline_state {
    vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
        : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}

    uint32_t s0, s1, p0, p1, d0, d1, KW, KH;

    bool operator<(const vk_conv2d_pipeline_state &b) const {
        return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
               std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
    }
};

struct vk_solve_tri_pipeline_state {
    vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
        : N(N), K(K) {}

    uint32_t N, K;

    bool operator<(const vk_solve_tri_pipeline_state &b) const {
        return std::tie(N, K) <
               std::tie(b.N, b.K);
    }
};

enum shader_reduction_mode {
    SHADER_REDUCTION_MODE_SHMEM,
    SHADER_REDUCTION_MODE_HYBRID,
    SHADER_REDUCTION_MODE_SUBGROUP,
    SHADER_REDUCTION_MODE_COUNT,
};

// argsort pipelines for up to 1<<10 invocations per workgroup
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr uint32_t num_topk_pipelines = 11;

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
                                                                             GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY,    GGML_OP_RESHAPE,  GGML_OP_ADD,
                                                                            GGML_OP_ARGSORT,  GGML_OP_VIEW,     GGML_OP_GET_ROWS,
                                                                            GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
                                                                            GGML_OP_DIV,      GGML_OP_RESHAPE };

static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS };

static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,
                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

//node #978 (  SOFT_MAX):     ffn_moe_probs-15 (   0K) [Vulka         ] use=2:    ffn_moe_logits-15 (   0K) [Vulka         ]
//node #979 (   RESHAPE): ffn_moe_probs-15 (re (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
//node #980 (   ARGSORT):   ffn_moe_argsort-15 (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
//node #981 (      VIEW):      ffn_moe_topk-15 (   0K) [Vulka         ] use=4:   ffn_moe_argsort-15 (   0K) [Vulka         ]
//node #982 (  GET_ROWS):   ffn_moe_weights-15 (   0K) [Vulka         ] use=1: ffn_moe_probs-15 (re (   0K) [Vulka         ]      ffn_moe_topk-15 (   0K) [Vulka         ]
//node #983 (   RESHAPE): ffn_moe_weights-15 ( (   0K) [Vulka         ] use=2:   ffn_moe_weights-15 (   0K) [Vulka         ]
//node #984 (  SUM_ROWS): ffn_moe_weights_sum- (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ]
//node #985 (     CLAMP): ffn_moe_weights_sum_ (   0K) [Vulka         ] use=1: ffn_moe_weights_sum- (   0K) [Vulka         ]
//node #986 (       DIV): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ] ffn_moe_weights_sum_ (   0K) [Vulka         ]
//node #987 (   RESHAPE): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights_norm (   0K) [Vulka         ]
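// Each edge below is { consumer node index, src slot, producer node index }, with
// indices relative to the first node of the fused pattern, i.e.
// nodes[consumer]->src[slot] == nodes[producer].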
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
    { 1, 0, 0 }, // reshape->src[0]  == softmax
    { 2, 0, 0 }, // argsort->src[0]  == softmax
    { 3, 0, 2 }, // view->src[0]     == argsort
    { 4, 0, 1 }, // get_rows->src[0] == reshape
    { 4, 1, 3 }, // get_rows->src[1] == view
    { 5, 0, 4 }, // reshape->src[0]  == get_rows
    { 6, 0, 5 }, // sum_rows->src[0] == reshape
    { 7, 0, 6 }, // clamp->src[0]    == sum_rows
    { 8, 0, 5 }, // div->src[0]      == reshape
    { 8, 1, 7 }, // div->src[1]      == clamp
    { 9, 0, 8 }, // reshape->src[0]  == div
};

//node #436 (     UNARY):     ffn_moe_probs-10 ( 256K) [Vulka         ] use=2:    ffn_moe_logits-10 ( 256K) [Vulka         ]
//node #437 (   RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ]
//node #438 (       ADD): ffn_moe_probs_biased ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ] blk.10.exp_probs_b.b (   0K) [Vulka         ]
//node #439 (   ARGSORT):   ffn_moe_argsort-10 ( 256K) [Vulka         ] use=1: ffn_moe_probs_biased ( 256K) [Vulka         ]
//node #440 (      VIEW):      ffn_moe_topk-10 ( 255K) [Vulka         ] use=3:   ffn_moe_argsort-10 ( 256K) [Vulka         ]
//node #441 (  GET_ROWS):   ffn_moe_weights-10 (  12K) [Vulka         ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka         ]      ffn_moe_topk-10 ( 255K) [Vulka         ]
//node #442 (   RESHAPE): ffn_moe_weights-10 ( (  12K) [Vulka         ] use=2:   ffn_moe_weights-10 (  12K) [Vulka         ]
//node #443 (  SUM_ROWS): ffn_moe_weights_sum- (   2K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ]
//node #444 (     CLAMP): ffn_moe_weights_sum_ (   2K) [Vulka         ] use=1: ffn_moe_weights_sum- (   2K) [Vulka         ]
//node #445 (       DIV): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ] ffn_moe_weights_sum_ (   2K) [Vulka         ]
//node #446 (   RESHAPE): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights_norm (  12K) [Vulka         ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
    { 1, 0, 0 }, // reshape->src[0]  == sigmoid
    { 2, 0, 0 }, // add->src[0]      == sigmoid
    { 3, 0, 2 }, // argsort->src[0]  == add
    { 4, 0, 3 }, // view->src[0]     == argsort
    { 5, 0, 1 }, // get_rows->src[0] == reshape
    { 5, 1, 4 }, // get_rows->src[1] == view
    { 6, 0, 5 }, // reshape->src[0]  == get_rows
    { 7, 0, 6 }, // sum_rows->src[0] == reshape
    { 8, 0, 7 }, // clamp->src[0]    == sum_rows
    { 9, 0, 6 }, // div->src[0]      == reshape
    { 9, 1, 8 }, // div->src[1]      == clamp
    {10, 0, 9 }, // reshape->src[0]  == div
};

// same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
    { 1, 0, 0 }, // reshape->src[0]  == softmax
    { 2, 0, 0 }, // argsort->src[0]  == softmax
    { 3, 0, 2 }, // view->src[0]     == argsort
    { 4, 0, 1 }, // get_rows->src[0] == reshape
    { 4, 1, 3 }, // get_rows->src[1] == view
};

//node #652 (   ARGSORT):   ffn_moe_argsort-11 (   0K) [Vulka         ] use=1:     ffn_moe_probs-11 (   0K) [Vulka         ]
//node #653 (      VIEW):      ffn_moe_topk-11 (   0K) [Vulka         ] use=7:   ffn_moe_argsort-11 (   0K) [Vulka         ]
//node #654 (  GET_ROWS):   ffn_moe_weights-11 (   0K) [Vulka         ] use=1: ffn_moe_probs-11 (re (   0K) [Vulka         ]      ffn_moe_topk-11 (   0K) [Vulka         ]
//node #655 (   RESHAPE): ffn_moe_weights-11 ( (   0K) [Vulka         ] use=1:   ffn_moe_weights-11 (   0K) [Vulka         ]
//node #656 (  SOFT_MAX):             node_656 (   0K) [Vulka         ] use=1: ffn_moe_weights-11 ( (   0K) [Vulka         ]
//node #657 (   RESHAPE): ffn_moe_weights_soft (   0K) [Vulka         ] use=1:             node_656 (   0K) [Vulka         ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
    { 1, 0, 0 }, // view->src[0]     == argsort
    { 2, 1, 1 }, // get_rows->src[1] == view
    { 3, 0, 2 }, // reshape->src[0]  == get_rows
    { 4, 0, 3 }, // soft_max->src[0] == reshape
    { 5, 0, 4 }, // reshape->src[0]  == soft_max
};

enum topk_moe_mode {
    TOPK_MOE_EARLY_SOFTMAX,
    TOPK_MOE_EARLY_SOFTMAX_NORM,
    TOPK_MOE_LATE_SOFTMAX,
    TOPK_MOE_SIGMOID_NORM_BIAS,
    TOPK_MOE_COUNT,
};

static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
    { 1, 0, 0 }, // view->src[0]     == rope
    { 2, 0, 1 }, // set_rows->src[0] == view
};

static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
    { 1, 0, 0 }, // mul->src[0]      == rms
    { 2, 0, 1 }, // rope->src[0]     == mul
    { 3, 0, 2 }, // view->src[0]     == rope
    { 4, 0, 3 }, // set_rows->src[0] == view
};


struct vk_device_struct {
    std::recursive_mutex mutex;

    vk::PhysicalDevice physical_device;
    vk::PhysicalDeviceProperties properties;
    std::string name;
    uint64_t max_memory_allocation_size;
    uint64_t max_buffer_size;
    uint64_t suballocation_block_size;
    uint64_t min_imported_host_pointer_alignment;
    bool external_memory_host {};
    bool fp16;
    bool bf16;
    bool pipeline_robustness;
    bool memory_priority;
    vk::Device device;
    uint32_t vendor_id;
    vk::DriverId driver_id;
    vk_device_architecture architecture;
    vk_queue compute_queue;
    vk_queue transfer_queue;
    bool single_queue;
    bool support_async;
    uint32_t subgroup_size;
    uint32_t subgroup_size_log2;
    uint32_t shader_core_count;
    bool uma;
    bool prefer_host_memory;
    bool float_controls_rte_fp16;
    bool subgroup_basic;
    bool subgroup_arithmetic;
    bool subgroup_shuffle;
    bool subgroup_ballot;
    bool subgroup_clustered;
    bool subgroup_vote;
    bool multi_add;
    bool shader_int64;
    bool buffer_device_address;
    bool vulkan_memory_model;

    bool add_rms_fusion;
    uint32_t partials_binding_alignment;

    bool shader_64b_indexing;

    bool integer_dot_product;
    // 0: default, 1: force mmvq, -1: disable mmvq
    int32_t mmvq_mode;

    bool subgroup_size_control;
    uint32_t subgroup_min_size;
    uint32_t subgroup_max_size;
    bool subgroup_require_full_support;

    // floor(log2(maxComputeWorkGroupInvocations))
    uint32_t max_workgroup_size_log2 {};

    bool coopmat_support;
    bool coopmat_acc_f32_support {};
    bool coopmat_acc_f16_support {};
    bool coopmat_bf16_support {};
    bool coopmat_support_16x16x16_f16acc {};
    bool coopmat_support_16x16x16_f32acc {};
    bool coopmat1_fa_support {};
    uint32_t coopmat_m;
    uint32_t coopmat_n;
    uint32_t coopmat_k;

    bool coopmat_int_support;
    uint32_t coopmat_int_m;
    uint32_t coopmat_int_n;
    uint32_t coopmat_int_k;

    bool coopmat2;

    bool pipeline_executable_properties_support {};

    size_t idx;

    bool mul_mat_l[GGML_TYPE_COUNT];
    bool mul_mat_m[GGML_TYPE_COUNT];
    bool mul_mat_s[GGML_TYPE_COUNT];
    bool mul_mat_id_l[GGML_TYPE_COUNT];
    bool mul_mat_id_m[GGML_TYPE_COUNT];
    bool mul_mat_id_s[GGML_TYPE_COUNT];

    vk::DescriptorSetLayout dsl;

    vk_matmul_pipeline pipeline_matmul_f32 {};
    vk_matmul_pipeline pipeline_matmul_f32_f16 {};
    vk_matmul_pipeline pipeline_matmul_bf16 {};
    vk_matmul_pipeline2 pipeline_matmul_f16;
    vk_matmul_pipeline2 pipeline_matmul_f16_f32;

    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];

    vk_matmul_pipeline pipeline_matmul_id_f32 {};
    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
    vk_matmul_pipeline2 pipeline_matmul_id_f16;
    vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;

    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];

    vk_pipeline pipeline_matmul_split_k_reduce;
    vk_pipeline pipeline_quantize_q8_1_x4;

    vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
    vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
    vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
    vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];

    vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
    vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];

    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
    vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
    vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
    vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
    vk_pipeline pipeline_acc_f32;

    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
    vk_pipeline pipeline_add[2][2][2];
    vk_pipeline pipeline_add_norepeat[2][2][2];
    vk_pipeline pipeline_sub[2][2][2];
    vk_pipeline pipeline_sub_norepeat[2][2][2];
    vk_pipeline pipeline_mul[2][2][2];
    vk_pipeline pipeline_mul_norepeat[2][2][2];
    vk_pipeline pipeline_div[2][2][2];
    vk_pipeline pipeline_div_norepeat[2][2][2];
    vk_pipeline pipeline_add_rms[2][2][2];
    vk_pipeline pipeline_add_rms_norepeat[2][2][2];

    // indexed by num_additional_fused_ops == num_adds - 1
    vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
    vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];

    vk_pipeline pipeline_add_id_f32;

    vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
    vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
    vk_pipeline pipeline_scale_f32;
    vk_pipeline pipeline_sqr_f32;
    vk_pipeline pipeline_sqrt_f32;
    vk_pipeline pipeline_sin_f32;
    vk_pipeline pipeline_cos_f32;
    vk_pipeline pipeline_log[2];
    vk_pipeline pipeline_tri[2];
    vk_pipeline pipeline_diag[2];
    vk_pipeline pipeline_clamp_f32;
    vk_pipeline pipeline_pad_f32;
    vk_pipeline pipeline_roll_f32;
    vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
    vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
    vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
    vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
    vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
    vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
    vk_pipeline pipeline_norm_f32;
    vk_pipeline pipeline_group_norm_f32;
    vk_pipeline pipeline_rms_norm_f32;
    vk_pipeline pipeline_rms_norm_mul_f32;
    vk_pipeline pipeline_rms_norm_partials_f32;
    vk_pipeline pipeline_rms_norm_mul_partials_f32;
    vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
    vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
    vk_pipeline pipeline_rms_norm_back_f32;
    vk_pipeline pipeline_l2_norm_f32;

    // [src/dst 0=fp32,1=fp16]
    vk_pipeline pipeline_exp[2];
    vk_pipeline pipeline_gelu[2];
    vk_pipeline pipeline_gelu_erf[2];
    vk_pipeline pipeline_gelu_quick[2];
    vk_pipeline pipeline_silu[2];
    vk_pipeline pipeline_relu[2];
    vk_pipeline pipeline_xielu[2];
    vk_pipeline pipeline_neg[2];
    vk_pipeline pipeline_tanh[2];
    vk_pipeline pipeline_sigmoid[2];
    vk_pipeline pipeline_hardsigmoid[2];
    vk_pipeline pipeline_hardswish[2];
    vk_pipeline pipeline_abs[2];
    vk_pipeline pipeline_softplus[2];
    vk_pipeline pipeline_step[2];
    vk_pipeline pipeline_round[2];
    vk_pipeline pipeline_ceil[2];
    vk_pipeline pipeline_floor[2];
    vk_pipeline pipeline_trunc[2];

    vk_pipeline pipeline_add1_f16_f16;
    vk_pipeline pipeline_add1_f16_f32;
    vk_pipeline pipeline_add1_f32_f32;

    vk_pipeline pipeline_arange_f32;

    vk_pipeline pipeline_fill_f32;

    vk_pipeline pipeline_geglu[2];
    vk_pipeline pipeline_reglu[2];
    vk_pipeline pipeline_swiglu[2];
    vk_pipeline pipeline_swiglu_oai[2];
    vk_pipeline pipeline_geglu_erf[2];
    vk_pipeline pipeline_geglu_quick[2];

    vk_pipeline pipeline_leaky_relu_f32;
    vk_pipeline pipeline_silu_back_f32;
    vk_pipeline pipeline_diag_mask_inf_f32;
    vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
    vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
    vk_pipeline pipeline_soft_max_back_f32;

    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;

    vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
    vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
    vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
    vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
    vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
    vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
    vk_pipeline pipeline_sum_rows_f32;
    vk_pipeline pipeline_cumsum_f32;
    vk_pipeline pipeline_cumsum_small_f32;
    vk_pipeline pipeline_cumsum_multipass1_f32;
    vk_pipeline pipeline_cumsum_multipass2_f32;
    vk_pipeline pipeline_argmax_f32;
    vk_pipeline pipeline_count_equal_i32;
    std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
    vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
    vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
    vk_pipeline pipeline_timestep_embedding_f32;
    vk_pipeline pipeline_conv_transpose_1d_f32;
    vk_pipeline pipeline_pool2d_f32;
    vk_pipeline pipeline_rwkv_wkv6_f32;
    vk_pipeline pipeline_rwkv_wkv7_f32;
    vk_pipeline pipeline_ssm_scan_f32_d128;
    vk_pipeline pipeline_ssm_scan_f32_d256;
    vk_pipeline pipeline_ssm_conv_f32;
    vk_pipeline pipeline_opt_step_adamw_f32;
    vk_pipeline pipeline_opt_step_sgd_f32;
    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
    std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;

    std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];

    std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;

    vk_pipeline pipeline_flash_attn_split_k_reduce;
    vk_pipeline pipeline_count_experts;

    // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];

    std::vector<vk_pipeline_ref> all_pipelines;

    std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;

    vk::Fence fence;
    vk_buffer sync_staging;

    ggml_backend_buffer_type buffer_type;

    bool disable_fusion;
    bool disable_host_visible_vidmem;
    bool allow_sysmem_fallback;
    bool disable_graph_optimize;

    std::unique_ptr<vk_memory_logger> memory_logger;

    ~vk_device_struct() {
        VK_LOG_DEBUG("destroy device " << name);

        device.destroyFence(fence);

        ggml_vk_destroy_buffer(sync_staging);

        compute_queue.cmd_pool.destroy(device);
        transfer_queue.cmd_pool.destroy(device);

        for (auto& pipeline : all_pipelines) {
            if (pipeline.expired()) {
                continue;
            }

            vk_pipeline pl = pipeline.lock();
            ggml_vk_destroy_pipeline(device, pl);
        }
        all_pipelines.clear();

        device.destroyDescriptorSetLayout(dsl);

        device.destroy();
    }
};

void vk_command_pool::init(vk_device& device, vk_queue *q_) {
    cmd_buffer_idx = 0;
    q = q_;

    vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
    pool = device->device.createCommandPool(command_pool_create_info);
}

void vk_command_pool::destroy(vk::Device& device) {
    device.destroyCommandPool(pool);
    pool = nullptr;
    cmd_buffers.clear();
}

struct vk_buffer_struct {
    vk::Buffer buffer = VK_NULL_HANDLE;
    vk::DeviceMemory device_memory = VK_NULL_HANDLE;
    vk::MemoryPropertyFlags memory_property_flags;
    void * ptr;
    size_t size = 0;
    vk::DeviceAddress bda_addr {};

    vk_device device;

    ~vk_buffer_struct() {
        if (size == 0) {
            return;
        }
        VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");

        device->device.freeMemory(device_memory);
        device->device.destroyBuffer(buffer);
    }
};

struct vk_subbuffer {
    vk_buffer buffer;
    uint64_t offset;
    uint64_t size;

    operator vk::DescriptorBufferInfo() const {
        return { buffer->buffer, offset, size };
    }
};
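// The conversion operator above lets a vk_subbuffer be passed directly wherever a
// vk::DescriptorBufferInfo is expected when writing descriptor sets, e.g. (sketch):
//   vk::DescriptorBufferInfo info = vk_subbuffer{ buf, 0, buf->size };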

// vk_event is used for the event-related backend interfaces. It uses 'event' for
// event_wait and 'fence' for event_synchronize. Polling on an event for
// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
// and would lead to validation errors.
struct vk_event {
    vk::Event event;
    vk::Fence fence;
};

struct vk_semaphore {
    vk::Semaphore s;
    uint64_t value;
};

struct vk_submission {
    vk::CommandBuffer buffer;
    std::vector<vk_semaphore> wait_semaphores;
    std::vector<vk_semaphore> signal_semaphores;
};

typedef std::vector<vk_submission> vk_sequence;

struct vk_mat_mat_push_constants {
    uint32_t M; uint32_t N; uint32_t K;
    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t k_split;
    uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
    uint32_t padded_N;
};

#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8

struct vk_mat_vec_push_constants {
    uint32_t ncols;
    uint32_t stride_a;
    uint32_t stride_b;
    uint32_t stride_d;
    uint32_t batch_stride_a;
    uint32_t batch_stride_b;
    uint32_t batch_stride_d;
    uint32_t fusion_flags;
    uint32_t ne02;
    uint32_t ne12;
    uint32_t broadcast2;
    uint32_t broadcast3;
};

struct vk_mat_vec_p021_push_constants {
    uint32_t ncols_x;
    uint32_t nrows_x;
    uint32_t nchannels_x;
    uint32_t nchannels_y;
    uint32_t b_offset;
    uint32_t d_offset;
    uint32_t fusion_flags;
};

struct vk_mat_vec_nc_push_constants {
    uint32_t ncols_x;
    uint32_t nrows_x;
    uint32_t row_stride_x;
    uint32_t channel_stride_x;
    uint32_t channel_stride_y;
    uint32_t channel_x_divisor;
    uint32_t ne12;
    uint32_t b_offset;
    uint32_t d_offset;
    uint32_t nb03;
    uint32_t nb13;
    uint32_t nb23;
    uint32_t fusion_flags;
};

struct vk_mat_mat_id_push_constants {
    uint32_t M; uint32_t N; uint32_t K;
    uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
    uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
    uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
    uint32_t padded_N;
};
struct vk_mat_vec_id_push_constants {
    uint32_t ncols;
    uint32_t stride_a;
    uint32_t stride_b;
    uint32_t stride_d;
    uint32_t batch_stride_a;
    uint32_t batch_stride_b;
    uint32_t batch_stride_d;
    uint32_t fusion_flags;
    uint32_t nei0;
    uint32_t ne11;
    uint32_t expert_i1;
    uint32_t nbi1;
};

struct vk_flash_attn_push_constants {
    uint32_t N;
    uint32_t KV;

    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3;

    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    uint32_t nem1;
    uint32_t nem2;
    uint32_t nem3;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;

    float scale;
    float max_bias;
    float logit_softcap;

    uint32_t mask_n_head_log2;
    float m0;
    float m1;

    uint32_t gqa_ratio;
    uint32_t split_kv;
    uint32_t k_num;
};
static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
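// 128 bytes is the minimum maxPushConstantsSize the Vulkan spec guarantees, so staying
// within it keeps these push constants usable on any conformant device; the other
// large push-constant structs below assert the same bound.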

struct vk_op_push_constants {
    uint32_t KX;
    uint32_t KY;
    float param1;
    float param2;
    float param3;
    float param4;
};

struct vk_op_count_experts_push_constants {
    uint32_t ne00;
    uint32_t ne01;
    uint32_t nb00;
    uint32_t nb01;
    uint32_t a_offset;
};

struct vk_op_glu_push_constants {
    uint32_t N;
    uint32_t ne00;
    uint32_t ne20;
    uint32_t mode;  // 0: default, 1: swapped, 2: split
    float alpha; // for swiglu_oai
    float limit;
};

struct vk_op_unary_push_constants {
    uint32_t ne;
    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
    uint32_t misalign_offsets;
    float param1; float param2;
    uint32_t ne0_012mp; uint32_t ne0_012L;
    uint32_t ne0_01mp;  uint32_t ne0_01L;
    uint32_t ne0_0mp;   uint32_t ne0_0L;
    uint32_t ne1_012mp; uint32_t ne1_012L;
    uint32_t ne1_01mp;  uint32_t ne1_01L;
    uint32_t ne1_0mp;   uint32_t ne1_0L;
};
static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");

static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
    GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
    ne = ne != 0 ? ne : ggml_nelements(dst);
    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());

    vk_op_unary_push_constants p{};
    p.ne = (uint32_t)ne;

    size_t src0_tsize = ggml_type_size(src0->type);
    p.ne00 = (uint32_t)src0->ne[0];
    p.ne01 = (uint32_t)src0->ne[1];
    p.ne02 = (uint32_t)src0->ne[2];
    p.ne03 = (uint32_t)src0->ne[3];
    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);

    size_t dst_tsize = ggml_type_size(dst->type);
    p.ne10 = (uint32_t)dst->ne[0];
    p.ne11 = (uint32_t)dst->ne[1];
    p.ne12 = (uint32_t)dst->ne[2];
    p.ne13 = (uint32_t)dst->ne[3];
    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);

    return p; // offsets are initialized later in ggml_vk_op
}

struct vk_op_pad_push_constants {
    uint32_t ne;
    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
    uint32_t misalign_offsets;
    uint32_t circular;

    uint32_t lp0; uint32_t rp0;
    uint32_t lp1; uint32_t rp1;
    uint32_t lp2; uint32_t rp2;
    uint32_t lp3; uint32_t rp3;
};

static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
    int64_t ne = ggml_nelements(dst);
    GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());

    vk_op_pad_push_constants p{};
    p.ne = (uint32_t)ne;

    size_t src0_tsize = ggml_type_size(src0->type);
    p.ne00 = (uint32_t)src0->ne[0];
    p.ne01 = (uint32_t)src0->ne[1];
    p.ne02 = (uint32_t)src0->ne[2];
    p.ne03 = (uint32_t)src0->ne[3];
    p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
    p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
    p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
    p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);

    size_t dst_tsize = ggml_type_size(dst->type);
    p.ne10 = (uint32_t)dst->ne[0];
    p.ne11 = (uint32_t)dst->ne[1];
    p.ne12 = (uint32_t)dst->ne[2];
    p.ne13 = (uint32_t)dst->ne[3];
    p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
    p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
    p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
    p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);

    p.lp0 = dst->op_params[0];
    p.rp0 = dst->op_params[1];
    p.lp1 = dst->op_params[2];
    p.rp1 = dst->op_params[3];
    p.lp2 = dst->op_params[4];
    p.rp2 = dst->op_params[5];
    p.lp3 = dst->op_params[6];
    p.rp3 = dst->op_params[7];
    p.circular = dst->op_params[8];

    return p; // fastdiv values and offsets are initialized later in ggml_vk_op
}

// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
{
    // compute L = ceil(log2(d));
    L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }

    mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
}
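
// Host-side sketch of the division the shaders then perform with these values
// (illustrative only; in the shaders mulhi is done with umulExtended):
//   static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L) {
//       uint64_t hi = ((uint64_t)n * mp) >> 32; // mulhi(n, mp)
//       return (uint32_t)((hi + n) >> L);       // == n / d for the d passed to init_fastdiv_values
//   }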

template <typename T> void init_pushconst_fastdiv(T &p) {
    GGML_UNUSED(p);
    static_assert(!std::is_const<T>::value, "unexpected type");
}

template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
    // Compute magic values to divide by these six numbers.
    init_fastdiv_values(p.ne02*p.ne01*p.ne00,  p.ne0_012mp,    p.ne0_012L);
    init_fastdiv_values(p.ne01*p.ne00,         p.ne0_01mp,     p.ne0_01L);
    init_fastdiv_values(p.ne00,                p.ne0_0mp,      p.ne0_0L);
    init_fastdiv_values(p.ne12*p.ne11*p.ne10,  p.ne1_012mp,    p.ne1_012L);
    init_fastdiv_values(p.ne11*p.ne10,         p.ne1_01mp,     p.ne1_01L);
    init_fastdiv_values(p.ne10,                p.ne1_0mp,      p.ne1_0L);
}

struct vk_op_binary_push_constants {
    uint32_t ne;
    uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
    uint32_t misalign_offsets;
    float param1; float param2; int32_t param3;
};

struct vk_op_multi_add_push_constants {
    // shape for dst
    uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;

    // strides for srcs+dst
    uint32_t nb[MAX_PARAMETER_COUNT][4];

    uint32_t rms_partials;
};
// update multi_add.comp if this changes
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);

struct vk_op_topk_moe_push_constants {
    uint32_t n_rows;
    uint32_t n_experts_push;
    uint32_t n_expert_used;
    float clamp_min;
    float clamp_max;
    uint32_t gating_func;
    uint32_t has_bias;
    uint32_t with_norm;
    float output_scale;
    float output_bias;
};

struct vk_op_add_id_push_constants {
    uint32_t ne0;
    uint32_t ne1;
    uint32_t s01;
    uint32_t s02;
    uint32_t s11;
    uint32_t s21;
};

struct vk_op_diag_mask_push_constants {
    uint32_t ncols;
    uint32_t rows_per_channel;
    int32_t n_past;
};

struct vk_op_rope_push_constants {
    uint32_t rope_mode;
    uint32_t nrows;
    uint32_t n_dims;
    float freq_scale;
    float freq_base;
    float ext_factor;
    float attn_factor;
    float corr_dims[2];
    float theta_scale;
    uint32_t has_ff;
    int32_t sections[4];
    uint32_t is_imrope;
    uint32_t is_back;
    uint32_t set_rows_stride;
    uint32_t ne00;
    uint32_t ne01;
    uint32_t ne02;
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
};
static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");

// For fused rms_norm+mul+rope(+view+set_rows)
struct vk_op_rms_norm_mul_rope_push_constants {
    vk_op_binary_push_constants bin;
    vk_op_rope_push_constants rope;
};

struct vk_op_soft_max_push_constants {
    uint32_t KX;
    uint32_t KY;
    uint32_t ne00;
    uint32_t ne01;
    uint32_t ne02;
    uint32_t ne12;
    uint32_t ne13;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    float scale;
    float max_bias;
    float m0;
    float m1;
    uint32_t n_head_log2;
    uint32_t nrows_x;
    uint32_t has_sinks;
};

struct vk_op_argsort_push_constants {
    uint32_t ncols;
    uint32_t ncols_padded;
    uint32_t ncols_padded_log2;
    uint32_t nrows;
    uint32_t order;
    uint32_t outer_start;
    uint32_t outer_end;
    uint32_t inner_start;
    uint32_t inner_end;
};

struct vk_op_topk_push_constants {
    uint32_t orig_ncols;
    uint32_t ncols_input;
    uint32_t ncols_output;
    uint32_t k;
    uint32_t nrows;
    uint32_t first_pass;
    uint32_t last_pass;
};

struct vk_op_im2col_push_constants {
    uint64_t dst_addr;
    uint32_t batch_offset; uint32_t offset_delta;
    uint32_t IC;
    uint32_t IW; uint32_t IH;
    uint32_t OW; uint32_t OH;
    uint32_t KW; uint32_t KH;
    uint32_t pelements;
    uint32_t CHW;
    int32_t s0; int32_t s1;
    int32_t p0; int32_t p1;
    int32_t d0; int32_t d1;
    uint32_t batch_IC;
};

struct vk_op_im2col_3d_push_constants {
    uint64_t dst_addr;
    uint32_t nb10;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t s0;
    uint32_t s1;
    uint32_t s2;
    uint32_t p0;
    uint32_t p1;
    uint32_t p2;
    uint32_t d0;
    uint32_t d1;
    uint32_t d2;
    uint32_t IW;
    uint32_t IH;
    uint32_t ID;
    uint32_t IC;
    uint32_t KW;
    uint32_t OH;
    uint32_t KD_KH_KW;
    uint32_t KH_KW;
    uint32_t IC_KD_KH_KW;
    uint32_t N_OD_OH;
    uint32_t OD_OH;
    uint32_t OD_OH_OW_IC_KD_KH_KW;
    uint32_t OH_OW_IC_KD_KH_KW;
    uint32_t OW_IC_KD_KH_KW;
    uint32_t misalign_offsets;
};

struct vk_op_timestep_embedding_push_constants {
    uint32_t nb1;
    uint32_t dim;
    uint32_t max_period;
};

struct vk_op_conv_transpose_1d_push_constants {
    uint32_t Cout;
    uint32_t Cin;
    uint32_t K;
    uint32_t L;
    uint32_t KL;

    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb11;
    uint32_t nb1;

    int32_t s0;
};

struct vk_op_pool2d_push_constants {
    uint32_t IW; uint32_t IH;
    uint32_t OW; uint32_t OH;
    uint32_t OC;
    uint32_t pelements;
    uint32_t op;
    int32_t k0; int32_t k1;
    int32_t s0; int32_t s1;
    int32_t p0; int32_t p1;
};

struct vk_op_rwkv_wkv6_push_constants {
    uint32_t B;
    uint32_t T;
    uint32_t C;
    uint32_t H;
};

struct vk_op_rwkv_wkv7_push_constants {
    uint32_t B;
    uint32_t T;
    uint32_t C;
    uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
    uint32_t nb02, nb03, nb12, nb13;
    uint32_t nb21, nb22, nb31;
    uint32_t nb42, nb43, nb52, nb53;
    uint32_t s_off;
    uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
    uint32_t nb01, nb02;
    uint32_t nb11;
    uint32_t dst_nb0, dst_nb1, dst_nb2;
    uint32_t nc, ncs, nr, n_t, n_s;
};

 1448struct vk_op_conv2d_push_constants {
 1449    uint32_t Cout;
 1450    uint32_t Cin;
 1451    uint32_t N;
 1452
 1453    uint32_t W;
 1454    uint32_t H;
 1455    uint32_t OW;
 1456    uint32_t OH;
 1457
 1458    uint32_t nb01;
 1459    uint32_t nb02;
 1460    uint32_t nb03;
 1461
 1462    uint32_t nb11;
 1463    uint32_t nb12;
 1464    uint32_t nb13;
 1465
 1466    uint32_t nb1;
 1467    uint32_t nb2;
 1468    uint32_t nb3;
 1469
 1470    // init_fastdiv_values constants for dividing by OW, OW*OH
 1471    uint32_t OWmp;   uint32_t OWL;
 1472    uint32_t OWOHmp; uint32_t OWOHL;
 1473};
 1474
 1475template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
 1476    // Compute magic values to divide by OW, OW*OH
 1477    init_fastdiv_values(p.OW,       p.OWmp,    p.OWL);
 1478    init_fastdiv_values(p.OW*p.OH,  p.OWOHmp,  p.OWOHL);
 1479}
 1480
 1481struct vk_op_conv2d_dw_push_constants {
 1482    uint32_t ne;
 1483    uint32_t batches;
 1484    uint32_t channels;
 1485    uint32_t dst_w;
 1486    uint32_t dst_h;
 1487    uint32_t src_w;
 1488    uint32_t src_h;
 1489    uint32_t knl_w;
 1490    uint32_t knl_h;
 1491    int32_t stride_x;
 1492    int32_t stride_y;
 1493    int32_t pad_x;
 1494    int32_t pad_y;
 1495    int32_t dilation_x;
 1496    int32_t dilation_y;
 1497};
 1498
 1499struct vk_op_upscale_push_constants {
 1500    uint32_t ne; uint32_t a_offset; uint32_t d_offset;
 1501    uint32_t ne00; uint32_t ne01;
 1502    uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
 1503    uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
 1504    float sf0; float sf1; float sf2; float sf3;
 1505    float pixel_offset;
 1506};
 1507
 1508struct vk_op_sum_rows_push_constants
 1509{
 1510    uint32_t n_cols;
 1511    uint32_t ne01, ne02;
 1512    uint32_t nb01, nb02, nb03;
 1513    uint32_t nb11, nb12, nb13;
 1514    float weight;
 1515    uint32_t misalign_offsets;
 1516    uint32_t ne0_12mp, ne0_12L;
 1517    uint32_t ne0_1mp, ne0_1L;
 1518};
 1519
 1520static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
 1521    uint32_t type_size = (uint32_t)ggml_type_size(src->type);
 1522    vk_op_sum_rows_push_constants p = {};
 1523    p.n_cols = (uint32_t)n_cols;
 1524    p.ne01 = (uint32_t)src->ne[1];
 1525    p.ne02 = (uint32_t)src->ne[2];
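    // Convert byte strides to element strides; src and dst are assumed to
    // share the same element size here.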
 1526    p.nb01 = (uint32_t)src->nb[1] / type_size;
 1527    p.nb02 = (uint32_t)src->nb[2] / type_size;
 1528    p.nb03 = (uint32_t)src->nb[3] / type_size;
 1529    p.nb11 = (uint32_t)dst->nb[1] / type_size;
 1530    p.nb12 = (uint32_t)dst->nb[2] / type_size;
 1531    p.nb13 = (uint32_t)dst->nb[3] / type_size;
 1532    p.weight = 1.0f;
 1533    return p;
 1534}
 1535
 1536template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
 1537    init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
 1538    init_fastdiv_values(p.ne01,        p.ne0_1mp,  p.ne0_1L);
 1539}
 1540
 1541struct vk_quantize_q8_1_push_constants {
 1542    uint32_t ne;
 1543    uint32_t num_blocks;
 1544};
 1545
 1546struct vk_op_flash_attn_split_k_reduce_push_constants {
 1547    uint32_t D;
 1548    uint32_t ne1;
 1549    uint32_t ne2;
 1550    uint32_t ne3;
 1551    uint32_t k_num;
 1552    uint32_t sinks;
 1553};
 1554
 1555struct vk_op_flash_attn_mask_opt_push_constants {
 1556    uint32_t nem0;
 1557    uint32_t nem1;
 1558    uint32_t nem2;
 1559    uint32_t nbm1;
 1560    uint32_t nbm2;
 1561    uint32_t nbm3;
 1562    uint32_t nbd1;
 1563    uint32_t nbd2;
 1564    uint32_t nbd3;
 1565};
 1566
// Deferred host-side copies/fills that allow pre-recording command buffers
 1568struct vk_staging_memcpy {
 1569    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
 1570
 1571    void * dst;
 1572    const void * src;
 1573    size_t n;
 1574};
 1575
 1576struct vk_staging_memset {
 1577    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
 1578
 1579    void * dst;
 1580    uint32_t val;
 1581    size_t n;
 1582};
 1583
 1584struct vk_context_struct {
 1585    vk_submission * s;
 1586    std::vector<vk_sequence> seqs;
 1587
 1588    int exit_tensor_idx;
 1589
 1590    std::vector<vk_staging_memcpy> in_memcpys;
 1591    std::vector<vk_staging_memcpy> out_memcpys;
 1592    std::vector<vk_staging_memset> memsets;
 1593
 1594    vk_command_pool * p {};
 1595};
 1596typedef std::shared_ptr<vk_context_struct> vk_context;
 1597typedef std::weak_ptr<vk_context_struct> vk_context_ref;
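// Contexts are owned through shared_ptr (vk_context); vk_context_ref is a weak
// reference, so cached handles do not keep a finished context alive.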
 1598
 1599struct ggml_vk_garbage_collector {
 1600    std::vector<vk_semaphore> tl_semaphores;
 1601    std::vector<vk_semaphore> semaphores;
 1602    std::vector<vk::Event> events;
 1603    std::vector<vk_context> contexts;
 1604};
 1605
 1606static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
 1607static void ggml_vk_load_shaders(vk_device& device);
 1608static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
 1609
 1610static bool vk_memory_logger_enabled = false;
 1611
 1612#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << "ggml_vulkan memory: " << msg << std::endl; }
 1613
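// e.g. format_size(3ull << 30) returns "3.00 GiB" and format_size(512) returns "512 B".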
 1614static std::string format_size(size_t size) {
 1615    const size_t kib = 1024;
 1616    const size_t mib = kib * 1024;
 1617    const size_t gib = mib * 1024;
 1618
 1619    std::ostringstream oss;
 1620    oss << std::fixed << std::setprecision(2);
 1621
 1622    if (size >= gib) {
 1623        oss << static_cast<double>(size) / gib << " GiB";
 1624    } else if (size >= mib) {
 1625        oss << static_cast<double>(size) / mib << " MiB";
 1626    } else if (size >= kib) {
 1627        oss << static_cast<double>(size) / kib << " KiB";
 1628    } else {
 1629        oss << size << " B";
 1630    }
 1631
 1632    return oss.str();
 1633}
 1634
 1635class vk_memory_logger {
 1636public:
 1637    vk_memory_logger(): total_device(0), total_host(0) {}
 1638    void log_allocation(vk_buffer_ref buf_ref, size_t size);
 1639    void log_deallocation(vk_buffer_ref buf_ref);
 1640
 1641private:
 1642    std::map<vk::Buffer, size_t> allocations; // Track allocations
 1643    size_t total_device;
 1644    size_t total_host;
 1645    static std::mutex log_mutex;
 1646};
 1647
 1648std::mutex vk_memory_logger::log_mutex;
 1649
 1650static bool vk_perf_logger_enabled = false;
 1651static bool vk_perf_logger_concurrent = false;
 1652static bool vk_enable_sync_logger = false;
 1653// number of calls between perf logger prints
 1654static uint32_t vk_perf_logger_frequency = 1;
 1655
 1656class vk_perf_logger {
 1657  public:
 1658    void print_timings(bool force = false) {
 1659        if (timings.empty()) {
 1660            return;
 1661        }
 1662        print_count++;
 1663        if ((print_count % vk_perf_logger_frequency) != 0 && !force) {
 1664            return;
 1665        }
 1666        print_count = 0;
 1667        uint64_t total_all_op_times = 0;
 1668        std::cerr << "----------------\nVulkan Timings:" << std::endl;
 1669        for (const auto & t : timings) {
 1670            uint64_t total_op_times = 0;
 1671            for (const auto & time : t.second) {
 1672                total_op_times += time;
 1673            }
 1674            std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
 1675                      << " us = " << (total_op_times / 1000.0) << " us";
 1676
            // If we have as many flops entries as timing entries for the op, compute and log the FLOP rate.
 1678            auto it = flops.find(t.first);
 1679            if (it != flops.end() && (it->second).size() == t.second.size()) {
 1680                uint64_t total_op_flops = 0;
 1681                for (const auto & elem : it->second) {
 1682                    total_op_flops += elem;
 1683                }
 1684                std::cerr << " ("
 1685                          << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
 1686                                 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
 1687                          << " GFLOPS/s)";
 1688            }
 1689
 1690            total_all_op_times += total_op_times;
 1691
 1692            std::cerr << std::endl;
 1693        }
 1694
 1695        if (timings.size() > 0) {
 1696            std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
 1697        }
 1698
 1699        timings.clear();
 1700        flops.clear();
 1701    }
 1702
 1703    std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
 1704        *n_flops = 0;
 1705        std::string fusion_str;
 1706        if (fusion_name) {
 1707            fusion_str = fusion_name + std::string(" ");
 1708        }
 1709        if (node->op == GGML_OP_UNARY) {
 1710            return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
 1711        }
 1712        if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
 1713            const uint64_t m     = node->ne[0];
 1714            const uint64_t n     = node->ne[1];
 1715            const uint64_t k     = node->src[1]->ne[0];
 1716            const uint64_t batch = node->ne[2] * node->ne[3];
 1717            std::string    name  = ggml_op_name(node->op);
 1718            if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
 1719                (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
 1720                name += "_VEC";
 1721            }
 1722            name += " ";
 1723            name += ggml_type_name(node->src[0]->type);
 1724            name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
 1725            if (node->op == GGML_OP_MUL_MAT_ID) {
 1726                name += " n_expert=" + std::to_string(node->src[0]->ne[2]);
 1727            }
 1728            if (batch > 1) {
 1729                name += " batch=" + std::to_string(batch);
 1730            }
 1731            name = fusion_str + name;
 1732            *n_flops = m * n * (k + (k - 1)) * batch;
 1733            return name;
 1734        }
 1735        if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
 1736            std::string   name    = ggml_op_name(node->op);
 1737            ggml_tensor * knl     = node->src[0];
 1738            uint64_t      OW      = node->ne[0];
 1739            uint64_t      OH      = node->ne[1];
 1740            uint64_t      N       = node->ne[3];
 1741            uint64_t      Cout    = node->ne[2];
 1742            uint64_t      KW      = knl->ne[0];
 1743            uint64_t      KH      = knl->ne[1];
 1744            uint64_t      Cin     = node->src[1]->ne[2];
 1745            // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
 1746            uint64_t      size_M  = Cout;
 1747            uint64_t      size_K  = Cin * KW * KH;
 1748            uint64_t      size_N  = N * OW * OH;
 1749            *n_flops = size_M * size_N * (size_K + (size_K - 1));
 1750            name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
 1751                    ", N=N*OW*OH=" + std::to_string(size_N);
 1752            name = fusion_str + name;
 1753            return name;
 1754        }
 1755        if (node->op == GGML_OP_RMS_NORM) {
 1756            std::string   name    = ggml_op_name(node->op);
 1757            name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
 1758            name = fusion_str + name;
 1759            return name;
 1760        }
 1761        if (node->op == GGML_OP_FLASH_ATTN_EXT) {
 1762            const ggml_tensor * dst = node;
 1763            const ggml_tensor * q = node->src[0];
 1764            const ggml_tensor * k = node->src[1];
 1765            const ggml_tensor * v = node->src[2];
 1766            const ggml_tensor * m = node->src[3];
 1767            std::stringstream name;
 1768            name << fusion_str;
 1769            name << ggml_op_name(node->op) <<
 1770                " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
 1771                " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
 1772                " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
 1773                " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
 1774                " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
 1775            *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
 1776            return name.str();
 1777        }
 1778        if (node->op == GGML_OP_TOP_K) {
 1779            std::stringstream name;
 1780            name << fusion_str;
 1781            name << ggml_op_name(node->op) <<
 1782                " K=" << node->ne[0] <<
 1783                " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
 1784            return name.str();
 1785        }
 1786        return fusion_str + ggml_op_name(node->op);
 1787    }
 1788
 1789    void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
 1790        uint64_t n_flops;
 1791        std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
 1792        if (n_flops) {
 1793            flops[name].push_back(n_flops);
 1794        }
 1795        timings[name].push_back(time);
 1796    }
 1797
 1798    void log_timing(const std::vector<ggml_tensor *> &nodes, const std::vector<const char *> &names, uint64_t time) {
 1799        uint64_t total_flops = 0;
 1800        std::string name;
 1801        for (size_t n = 0; n < nodes.size(); ++n) {
 1802            uint64_t n_flops = 0;
 1803            name += get_node_fusion_name(nodes[n], names[n], &n_flops);
 1804            total_flops += n_flops;
 1805
 1806            if (n != nodes.size() - 1) {
 1807                name += ", ";
 1808            }
 1809        }
 1810        if (total_flops) {
 1811            flops[name].push_back(total_flops);
 1812        }
 1813        timings[name].push_back(time);
 1814    }
 1815
 1816  private:
 1817    std::map<std::string, std::vector<uint64_t>> timings;
 1818    std::map<std::string, std::vector<uint64_t>> flops;
 1819    uint32_t print_count {};
 1820};
 1821
 1822struct ggml_backend_vk_context {
 1823    std::string name;
 1824
 1825    vk_device device;
 1826
 1827    size_t semaphore_idx, event_idx;
 1828    ggml_vk_garbage_collector gc;
 1829    size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
 1830    vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials, sync_staging;
 1831    vk::Fence fence, almost_ready_fence;
 1832    bool submit_pending {};
 1833    bool almost_ready_fence_pending {};
    // Set before op_add and unset after op_rms_norm to indicate that the add
    // should also write partial sums of the squared vector components for the
    // following rms_norm to consume
 1836    bool do_add_rms_partials_offset_calculation;
 1837    bool do_add_rms_partials;
 1838
 1839    uint64_t last_total_mul_mat_bytes {};
 1840
 1841    // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
 1842    vk_pipeline_struct * prealloc_y_last_pipeline_used {};
 1843    const ggml_tensor * prealloc_y_last_tensor_used {};
 1844
 1845    // Track which nodes have been used since the last sync, and whether they were written to
 1846    std::vector<const ggml_tensor *> unsynced_nodes_written;
 1847    std::vector<const ggml_tensor *> unsynced_nodes_read;
 1848    // Track which prealloc buffers have pending reads that need to be synchronized.
 1849    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
 1850    // and set to true after the buffer contents are consumed.
 1851    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
 1852
 1853    vk_context_ref compute_ctx;
 1854
 1855    std::vector<vk_context_ref> tensor_ctxs;
 1856
 1857    std::vector<vk::DescriptorPool> descriptor_pools;
 1858    std::vector<vk::DescriptorSet> descriptor_sets;
 1859    uint32_t descriptor_set_idx {};
 1860    uint32_t pipeline_descriptor_set_requirements {};
 1861
 1862    vk_command_pool compute_cmd_pool;
 1863
 1864    // number of additional consecutive nodes that are being fused with the
 1865    // node currently being processed
 1866    int num_additional_fused_ops {};
 1867    // Bitmask of which fused ops need to write an intermediate value to memory.
 1868    // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
 1869    // If there's no fusion, bit 0 is still set.
 1870    int fused_ops_write_mask {};
 1871    topk_moe_mode fused_topk_moe_mode {};
 1872    bool fused_topk_moe_scale {};
 1873
 1874    // for GGML_VK_PERF_LOGGER
 1875    std::unique_ptr<vk_perf_logger> perf_logger;
 1876    vk::QueryPool query_pool;
 1877    std::vector<const char *> query_fusion_names;
 1878    std::vector<int> query_fusion_node_count;
 1879    std::vector<ggml_tensor *> query_nodes;
 1880    std::vector<int> query_node_idx;
 1881    int32_t num_queries {};
 1882    int32_t query_idx {};
 1883};
 1884
 1885static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
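// Tensor "data" pointers in this backend are synthetic: buffers are based at
// vk_ptr_base, so subtracting it yields the byte offset into the underlying
// vk_buffer (e.g. data == vk_ptr_base + 4096 -> offset 4096).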
 1886
 1887static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
 1888    if (tensor->view_src) {
 1889        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
 1890    }
 1891    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
 1892}
 1893
 1894static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
 1895{
    return (vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
 1897}
 1898
 1899template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
 1900    GGML_UNUSED(p);
 1901    GGML_UNUSED(src0);
 1902    GGML_UNUSED(src1);
 1903    GGML_UNUSED(src2);
 1904    GGML_UNUSED(src3);
 1905    GGML_UNUSED(dst);
 1906    static_assert(!std::is_const<T>::value, "unexpected type");
 1907    GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
 1908    GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
 1909    GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
 1910    GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
 1911    GGML_ASSERT(!dst  || get_misalign_bytes(ctx, dst) == 0);
 1912}
 1913
 1914template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_p021_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
 1915    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
 1916    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
 1917
 1918    p.b_offset = b_offset;
 1919    p.d_offset = d_offset;
 1920
 1921    GGML_UNUSED(src0);
 1922    GGML_UNUSED(src2);
 1923    GGML_UNUSED(src3);
 1924}
 1925
 1926template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_nc_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
 1927    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
 1928    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
 1929
 1930    p.b_offset = b_offset;
 1931    p.d_offset = d_offset;
 1932
 1933    GGML_UNUSED(src0);
 1934    GGML_UNUSED(src2);
 1935    GGML_UNUSED(src3);
 1936}
 1937
 1938struct ggml_backend_vk_buffer_context {
 1939    vk_device_ref device;
 1940    vk_buffer dev_buffer;
 1941    std::string name;
 1942
 1943    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
 1944        device(device),
 1945        dev_buffer(dev_buffer),
 1946        name(name) {
 1947    }
 1948
 1949    ~ggml_backend_vk_buffer_context() {
 1950        ggml_vk_destroy_buffer(dev_buffer);
 1951    }
 1952};
 1953
 1954void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
 1955    if (!vk_memory_logger_enabled) {
 1956        return;
 1957    }
 1958    std::lock_guard<std::mutex> guard(log_mutex);
 1959    vk_buffer buf = buf_ref.lock();
 1960    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
 1961    const std::string type = device ? "device" : "host";
 1962    allocations[buf->buffer] = size;
 1963    total_device += device ? size : 0;
 1964    total_host += device ? 0 : size;
 1965    VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
 1966}
 1967
 1968void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
 1969    if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) {
 1970        return;
 1971    }
 1972
 1973    std::lock_guard<std::mutex> guard(log_mutex);
 1974    vk_buffer buf = buf_ref.lock();
 1975    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
 1976    std::string type = device ? "device" : "host";
    auto it = allocations.find(buf->buffer);
    if (it != allocations.end()) {
        // Only adjust the totals for allocations we actually tracked;
        // dereferencing an end() iterator here would be undefined behavior.
        total_device -= device ? it->second : 0;
        total_host -= device ? 0 : it->second;
        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
        allocations.erase(it);
 1983    } else {
 1984        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
 1985    }
 1986}
 1987
 1988struct vk_instance_t {
 1989    vk::Instance instance;
 1990
 1991    bool debug_utils_support = false;  // VK_EXT_debug_utils enabled
 1992    PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
 1993    PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
 1994    PFN_vkQueueEndDebugUtilsLabelEXT   pfn_vkQueueEndDebugUtilsLabelEXT   = {};
 1995    PFN_vkCmdBeginDebugUtilsLabelEXT   pfn_vkCmdBeginDebugUtilsLabelEXT   = {};
 1996    PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
 1997    PFN_vkCmdInsertDebugUtilsLabelEXT  pfn_vkCmdInsertDebugUtilsLabelEXT  = {};
 1998
 1999    std::vector<size_t> device_indices;
 2000    std::vector<bool>   device_supports_membudget;
 2001    vk_device devices[GGML_VK_MAX_DEVICES];
 2002};
 2003
 2004static bool vk_instance_initialized = false;
 2005static vk_instance_t vk_instance;
 2006
 2007#ifdef GGML_VULKAN_CHECK_RESULTS
 2008static size_t vk_skip_checks;
 2009static size_t vk_output_tensor;
 2010
 2011static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
 2012static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 2013static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 2014#endif
 2015
 2016typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 2017
 2018static void ggml_backend_vk_free(ggml_backend_t backend);
 2019
 2020static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
 2021    const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
 2022                                        VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
 2023    return range;
 2024}
 2025
 2026// Wait for ctx->fence to be signaled.
 2027static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
 2028    // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
 2029    // during this wait.
 2030    if (ctx->almost_ready_fence_pending) {
 2031        VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
 2032        ctx->device->device.resetFences({ ctx->almost_ready_fence });
 2033        ctx->almost_ready_fence_pending = false;
 2034    }
 2035
 2036    // Spin (w/pause) waiting for the graph to finish executing.
 2037    vk::Result result;
 2038    while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
 2039        if (result != vk::Result::eNotReady) {
 2040            fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
 2041            exit(1);
 2042        }
 2043        for (uint32_t i = 0; i < 100; ++i) {
 2044            YIELD();
 2045            YIELD();
 2046            YIELD();
 2047            YIELD();
 2048            YIELD();
 2049            YIELD();
 2050            YIELD();
 2051            YIELD();
 2052            YIELD();
 2053            YIELD();
 2054        }
 2055    }
 2056    ctx->device->device.resetFences({ ctx->fence });
 2057}
 2058
 2059// variables to track number of compiles in progress
 2060static uint32_t compile_count = 0;
 2061static std::mutex compile_count_mutex;
 2062static std::condition_variable compile_count_cond;
 2063
 2064static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
 2065                                         uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
 2066                                         bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
 2067    VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
 2068                 ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
 2069                 disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
 2070    GGML_ASSERT(parameter_count > 0);
 2071    GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
 2072    GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
 2073
 2074    vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
 2075    pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
 2076
 2077    vk::PushConstantRange pcr(
 2078        vk::ShaderStageFlagBits::eCompute,
 2079        0,
 2080        pipeline->push_constant_size
 2081    );
 2082
 2083    vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
 2084    pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
 2085
 2086    std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
 2087
 2088    for (size_t i = 0; i < specialization_constants.size(); i++) {
 2089        specialization_entries[i].constantID = i;
 2090        specialization_entries[i].offset = i * sizeof(uint32_t);
 2091        specialization_entries[i].size = sizeof(uint32_t);
 2092    }
 2093
 2094    vk::SpecializationInfo specialization_info(
 2095        specialization_entries.size(),
 2096        specialization_entries.data(),
 2097        specialization_constants.size() * sizeof(uint32_t),
 2098        specialization_constants.data()
 2099    );
 2100
 2101    vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
 2102
 2103    if (device->subgroup_require_full_support && require_full_subgroups) {
 2104        pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
 2105    }
 2106
 2107    vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
 2108            pipeline_shader_stage_create_flags,
 2109            vk::ShaderStageFlagBits::eCompute,
 2110            pipeline->shader_module,
 2111            entrypoint.c_str(),
 2112            &specialization_info);
 2113
 2114    vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
 2115    pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
 2116    if (device->subgroup_size_control && required_subgroup_size > 0) {
 2117        GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
 2118        pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
 2119    }
 2120
 2121    vk::ComputePipelineCreateInfo compute_pipeline_create_info(
 2122        device->pipeline_executable_properties_support ?
 2123            vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :
 2124            vk::PipelineCreateFlags{},
 2125        pipeline_shader_create_info,
 2126        pipeline->layout);
 2127
 2128    vk::PipelineRobustnessCreateInfoEXT rci;
 2129
 2130    if (device->pipeline_robustness && disable_robustness) {
 2131        rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
 2132        rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
 2133        compute_pipeline_create_info.setPNext(&rci);
 2134    }
 2135
 2136#if defined(VK_EXT_shader_64bit_indexing)
 2137    vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
 2138    if (pipeline->is_64b_indexing)
 2139    {
 2140        pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
 2141        if (device->pipeline_executable_properties_support) {
 2142            pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
 2143        }
 2144        pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
 2145        compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
 2146    }
 2147#endif
 2148
 2149    try {
 2150        pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
 2151    } catch (const vk::SystemError& e) {
 2152        std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
 2153        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
 2154        throw e;
 2155    }
 2156    pipeline->compiled = true;
 2157
 2158    if (vk_instance.debug_utils_support) {
 2159        vk::DebugUtilsObjectNameInfoEXT duoni;
 2160        duoni.objectType = vk::ObjectType::ePipeline;
 2161        duoni.pObjectName = pipeline->name.c_str();
        duoni.objectHandle = (uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
 2163        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
 2164    }
 2165
 2166    if (device->pipeline_executable_properties_support) {
 2167        vk::PipelineExecutableInfoKHR executableInfo;
 2168        executableInfo.pipeline = pipeline->pipeline;
 2169
 2170        auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
 2171        for (auto & s : statistics) {
 2172            // "Register Count" is reported by NVIDIA drivers.
 2173            if (strcmp(s.name, "Register Count") == 0) {
 2174                VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
 2175                pipeline->register_count = (uint32_t)s.value.u64;
 2176            }
 2177        }
 2178    }
 2179
 2180    device->all_pipelines.push_back(pipeline);
 2181
 2182    {
 2183        std::lock_guard<std::mutex> guard(compile_count_mutex);
 2184        assert(compile_count > 0);
 2185        compile_count--;
 2186    }
 2187    compile_count_cond.notify_all();
 2188}
 2189
 2190static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
 2191    VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")");
 2192    device.destroyPipelineLayout(pipeline->layout);
 2193
 2194    device.destroyShaderModule(pipeline->shader_module);
 2195
 2196    device.destroyPipeline(pipeline->pipeline);
 2197}
 2198
 2199static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
 2200    VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
 2201    ctx->pipeline_descriptor_set_requirements += n;
 2202    if (!pipeline->compiled) {
 2203        pipeline->needed = true;
 2204        ggml_vk_load_shaders(ctx->device);
 2205    }
 2206    ggml_pipeline_allocate_descriptor_sets(ctx);
 2207}
 2208
 2209static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
 2210
 2211    if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
 2212        // Enough descriptors are available
 2213        return;
 2214    }
 2215
 2216    vk_device& device = ctx->device;
 2217
 2218    // Grow by 50% to avoid frequent allocations
 2219    uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
 2220    uint32_t to_alloc = needed - ctx->descriptor_sets.size();
 2221    uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
 2222    uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
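    // Worked example (VK_DEVICE_DESCRIPTOR_POOL_SIZE = 256): with 300 sets
    // allocated and 400 required, needed = max(450, 400) = 450 and to_alloc = 150;
    // pool_remaining = 256 - 300 % 256 = 212 and pool_idx = 1, so all 150 new
    // sets come out of pool 1 and no additional pool is created.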
 2223
 2224    while (to_alloc > 0) {
 2225        const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
 2226        to_alloc -= alloc_count;
 2227        pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
 2228
 2229        if (pool_idx >= ctx->descriptor_pools.size()) {
 2230            vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
 2231            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
 2232            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
 2233        }
 2234
 2235        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
 2236        for (uint32_t i = 0; i < alloc_count; i++) {
 2237            layouts[i] = device->dsl;
 2238        }
 2239        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
 2240        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
 2241        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
 2242
 2243        pool_idx++;
 2244    }
 2245}
 2246
 2247static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
 2248    VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
 2249
 2250    if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
 2251        // Reuse command buffer
 2252        return p.cmd_buffers[p.cmd_buffer_idx++];
 2253    }
 2254
 2255    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
 2256        p.pool,
 2257        vk::CommandBufferLevel::ePrimary,
 2258        1);
 2259    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
 2260    auto buf = cmd_buffers.front();
 2261
 2262    p.cmd_buffers.push_back(buf);
 2263    p.cmd_buffer_idx++;
 2264
 2265    return buf;
 2266}
 2267
 2268static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
 2269    if (ctx->seqs.empty()) {
 2270        if (fence) {
 2271            std::lock_guard<std::mutex> guard(queue_mutex);
 2272            ctx->p->q->queue.submit({}, fence);
 2273        }
 2274        return;
 2275    }
 2276    VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
 2277
 2278    std::vector<std::vector<uint64_t>> tl_wait_vals;
 2279    std::vector<std::vector<uint64_t>> tl_signal_vals;
 2280    std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
 2281    std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
 2282    std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
 2283    std::vector<vk::SubmitInfo> submit_infos;
 2284    int idx = -1;
 2285    std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
 2286
 2287    size_t reserve = 0;
 2288
 2289    for (const auto& sequence : ctx->seqs) {
 2290        reserve += sequence.size();
 2291    }
 2292
 2293    // Pre-reserve vectors to prevent reallocation, which invalidates pointers
 2294    tl_wait_semaphores.reserve(reserve);
 2295    tl_wait_vals.reserve(reserve);
 2296    tl_signal_semaphores.reserve(reserve);
 2297    tl_signal_vals.reserve(reserve);
 2298    tl_submit_infos.reserve(reserve);
 2299    submit_infos.reserve(reserve);
 2300    stage_flags.reserve(reserve);
 2301
 2302    for (const auto& sequence : ctx->seqs) {
 2303        for (const auto& submission : sequence) {
 2304            stage_flags.push_back({});
 2305            idx++;
 2306            tl_wait_vals.push_back({});
 2307            tl_wait_semaphores.push_back({});
 2308            tl_signal_vals.push_back({});
 2309            tl_signal_semaphores.push_back({});
 2310            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
 2311                stage_flags[idx].push_back(ctx->p->q->stage_flags);
 2312                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
 2313                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
 2314            }
 2315            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
 2316                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
 2317                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
 2318            }
 2319            tl_submit_infos.push_back({
 2320                (uint32_t) submission.wait_semaphores.size(),
 2321                tl_wait_vals[idx].data(),
 2322                (uint32_t) submission.signal_semaphores.size(),
 2323                tl_signal_vals[idx].data(),
 2324            });
 2325            tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
 2326            tl_submit_infos[idx].pNext = nullptr;
 2327            vk::SubmitInfo si{
 2328                (uint32_t) submission.wait_semaphores.size(),
 2329                tl_wait_semaphores[idx].data(),
 2330                stage_flags[idx].data(),
 2331                1,
 2332                &submission.buffer,
 2333                (uint32_t) submission.signal_semaphores.size(),
 2334                tl_signal_semaphores[idx].data(),
 2335            };
 2336            si.setPNext(&tl_submit_infos[idx]);
 2337            submit_infos.push_back(si);
 2338        }
 2339    }
 2340
 2341    std::lock_guard<std::mutex> guard(queue_mutex);
 2342    ctx->p->q->queue.submit(submit_infos, fence);
 2343
 2344    ctx->seqs.clear();
 2345}
 2346
 2347static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
 2348    VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
 2349    const uint32_t qfsize = queue_family_props.size();
 2350
 2351    // Try with avoid preferences first
 2352    for (uint32_t i = 0; i < qfsize; i++) {
 2353        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
 2354            return i;
 2355        }
 2356    }
 2357
 2358    // Fall back to only required
 2359    for (size_t i = 0; i < qfsize; i++) {
 2360        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
 2361            return i;
 2362        }
 2363    }
 2364
 2365    // Fall back to reusing compute queue
 2366    for (size_t i = 0; i < qfsize; i++) {
 2367        if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
 2368            return i;
 2369        }
 2370    }
 2371
    // Fall back to ignoring min_num_queues
 2373    for (size_t i = 0; i < qfsize; i++) {
 2374        if (queue_family_props[i].queueFlags & required) {
 2375            return i;
 2376        }
 2377    }
 2378
 2379    // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
 2380    // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
 2381    if (compute_index >= 0) {
 2382        return compute_index;
 2383    }
 2384
 2385    std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
 2386
 2387    for(auto &q_family : queue_family_props) {
 2388        std::cerr << "Queue number: "  + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
 2389    }
 2390    abort();
 2391}
 2392
 2393static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
 2394    VK_LOG_DEBUG("ggml_vk_create_queue()");
 2395    std::lock_guard<std::recursive_mutex> guard(device->mutex);
 2396
 2397    q.queue_family_index = queue_family_index;
 2398    q.transfer_only = transfer_only;
 2399
 2400    q.cmd_pool.init(device, &q);
 2401
 2402    q.queue = device->device.getQueue(queue_family_index, queue_index);
 2403
 2404    q.stage_flags = stage_flags;
 2405}
 2406
 2407static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
 2408    vk_context result = std::make_shared<vk_context_struct>();
 2409    VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
 2410    ctx->gc.contexts.emplace_back(result);
 2411    result->p = &p;
 2412    return result;
 2413}
 2414
 2415static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
 2416    vk_context result = std::make_shared<vk_context_struct>();
 2417    VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
 2418    result->p = &p;
 2419    return result;
 2420}
 2421
 2422static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
 2423    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
 2424    vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
 2425    vk::SemaphoreCreateInfo ci{};
 2426    ci.setPNext(&tci);
 2427    vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
 2428    ctx->gc.semaphores.push_back({ semaphore, 0 });
 2429    return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
 2430}
 2431
 2432static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
 2433    VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
 2434    if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
 2435        vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
 2436        vk::SemaphoreCreateInfo ci{};
 2437        ci.setPNext(&tci);
 2438        vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
 2439        ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
 2440    }
 2441    return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
 2442}
 2443
 2444static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
 2445    if (ctx->event_idx >= ctx->gc.events.size()) {
 2446        ctx->gc.events.push_back(ctx->device->device.createEvent({}));
 2447    }
 2448    return ctx->gc.events[ctx->event_idx++];
 2449}
 2450
 2451static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
 2452    VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
 2453
 2454    // Requires command buffers to be done
 2455    device->device.resetCommandPool(p.pool);
 2456    p.cmd_buffer_idx = 0;
 2457}
 2458
 2459static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
 2460    VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
 2461
    // Arbitrary frequency at which to clean up and reuse command buffers
 2463    static constexpr uint32_t cleanup_frequency = 10;
 2464
 2465    if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
 2466        ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
 2467    }
 2468    if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
 2469        ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
 2470    }
 2471}
 2472
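// Return the indices of every memory type that has all requested property
// flags and whose heap is large enough for the request; e.g. asking for
// eDeviceLocal | eHostVisible only matches types carrying both bits.
// Callers try the returned indices in order.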
 2473static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
 2474    std::vector<uint32_t> indices;
 2475
 2476    for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
 2477        vk::MemoryType memory_type = mem_props->memoryTypes[i];
 2478        if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
 2479            (flags & memory_type.propertyFlags) == flags &&
 2480            mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
 2481            indices.push_back(i);
 2482        }
 2483    }
 2484    return indices;
 2485}
 2486
 2487static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
 2488                                       void *import_ptr = nullptr) {
 2489    VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
 2490    if (size > device->max_buffer_size) {
 2491        throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
 2492    }
 2493
 2494    vk_buffer buf = std::make_shared<vk_buffer_struct>();
 2495
 2496    if (size == 0) {
 2497        buf->size = 0;
 2498        return buf;
 2499    }
 2500
 2501    vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
 2502    vk::MemoryAllocateFlags mem_flags {};
 2503    if (device->buffer_device_address) {
 2504        usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
 2505        mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
 2506    }
 2507
 2508    vk::BufferCreateInfo buffer_create_info{
 2509        vk::BufferCreateFlags(),
 2510        size,
 2511        usage_flags,
 2512        vk::SharingMode::eExclusive,
 2513        0,
 2514        nullptr,
 2515    };
 2516
 2517    vk::ExternalMemoryBufferCreateInfo external_memory_bci;
 2518    if (import_ptr) {
 2519        external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
 2520        buffer_create_info.setPNext(&external_memory_bci);
 2521    }
 2522
 2523    buf->buffer = device->device.createBuffer(buffer_create_info);
 2524
 2525    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
 2526
 2527    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
 2528
 2529    const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
 2530
 2531    vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
 2532
 2533    if (device->memory_priority) {
 2534        mem_flags_info.setPNext(&mem_priority_info);
 2535    }
 2536
 2537    if (import_ptr) {
 2538        vk::MemoryHostPointerPropertiesEXT host_pointer_props;
 2539        try {
 2540            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
 2541        } catch (vk::SystemError& e) {
 2542            GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
 2543            device->device.destroyBuffer(buf->buffer);
 2544            return {};
 2545        }
 2547
 2548        uint32_t memory_type_idx;
 2549        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
 2550        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
 2551            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
 2552                continue;
 2553            }
 2554            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
 2555                continue;
 2556            }
 2557
 2558            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
            // Check for visible+coherent+cached; other flags (e.g. device-local) are allowed
 2560            if ((memory_type.propertyFlags & property_flags) == property_flags) {
 2561                property_flags = memory_type.propertyFlags;
 2562                break;
 2563            }
 2564        }
 2565        if (memory_type_idx == 32) {
 2566            GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
 2567            device->device.destroyBuffer(buf->buffer);
 2568            return {};
 2569        }
 2570
 2571        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
 2572        try {
 2573            vk::ImportMemoryHostPointerInfoEXT import_info;
 2574            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
 2575            import_info.pHostPointer = import_ptr;
 2576            import_info.setPNext(&mem_flags_info);
 2577            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
 2578        } catch (const vk::SystemError& e) {
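            // Import failed: leave buf->device_memory unset so the
            // !buf->device_memory check below reports the failure.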
 2579        }
 2580    } else {
 2581        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
 2582            const auto & req_flags = *it;
 2583
 2584            const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
 2585
 2586            if (memory_type_indices.empty()) {
 2587                continue;
 2588            }
 2589            buf->memory_property_flags = req_flags;
 2590
 2591            bool done = false;
 2592
 2593            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
 2594                try {
 2595                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
 2596                    done = true;
 2597                    break;
 2598                } catch (const vk::SystemError& e) {
 2599                    // loop and retry
 2600                    // during last attempt throw the exception
 2601                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
 2602                        device->device.destroyBuffer(buf->buffer);
 2603                        throw e;
 2604                    }
 2605                }
 2606            }
 2607
 2608            if (done) {
 2609                break;
 2610            }
 2611        }
 2612    }
 2613
 2614    if (!buf->device_memory) {
 2615        device->device.destroyBuffer(buf->buffer);
 2616        throw vk::OutOfDeviceMemoryError("No suitable memory type found");
 2617    }
 2618
 2619    buf->ptr = nullptr;
 2620
 2621    if (import_ptr) {
 2622        buf->ptr = import_ptr;
 2623    } else {
 2624        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
 2625            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
 2626        }
 2627    }
 2628
 2629    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
 2630
 2631    buf->device = device;
 2632    buf->size = size;
 2633
 2634    if (device->buffer_device_address) {
 2635        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
 2636        buf->bda_addr = device->device.getBufferAddress(addressInfo);
 2637    }
 2638
 2639    device->memory_logger->log_allocation(buf, size);
 2640
 2641    return buf;
 2642}
 2643
 2644static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
 2645    try {
 2646        return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
 2647    } catch (const vk::SystemError& e) {
 2648        std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
 2649        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
 2650        throw e;
 2651    }
 2652}
 2653
 2654static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
 2655    vk_buffer buf;
 2656    try {
 2657        if (device->prefer_host_memory) {
 2658            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
 2659                                                       vk::MemoryPropertyFlagBits::eDeviceLocal});
 2660        } else if (device->uma) {
 2661            // Fall back to host memory type
 2662            buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
 2663                                                       vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
 2664        } else if (device->disable_host_visible_vidmem) {
 2665            if (device->allow_sysmem_fallback) {
 2666                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
 2667                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
 2668            } else {
 2669                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
 2670            }
 2671        } else {
            // Use ReBAR (host-visible device-local memory) if available, otherwise fall back to device-local only memory
            if (device->allow_sysmem_fallback) {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
                                                           vk::MemoryPropertyFlagBits::eDeviceLocal,
                                                           vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
            } else {
                buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
                                                           vk::MemoryPropertyFlagBits::eDeviceLocal});
            }
        }
    } catch (const vk::SystemError& e) {
        std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
        std::cerr << "ggml_vulkan: " << e.what() << std::endl;
        throw;
    }

    return buf;
}

static void ggml_vk_destroy_buffer(vk_buffer& buf) {
    if (buf == nullptr) {
        return;
    }

    if (buf->device != nullptr) {
        buf->device->memory_logger->log_deallocation(buf);
    }

    buf.reset();
}

static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
    return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
}

static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
    VK_LOG_DEBUG("ggml_vk_sync_buffers()");

    const bool transfer_queue = subctx->p->q->transfer_only;

    if (ctx) {
        ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
    }

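    // Conservative full barrier within this queue's pipeline stages: all prior
    // shader/transfer writes are made visible to subsequent shader/transfer access.
    // On a transfer-only queue the access masks are narrowed to the transfer bits.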
    subctx->s->buffer.pipelineBarrier(
        subctx->p->q->stage_flags,
        subctx->p->q->stage_flags,
        {},
        { {
          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
          { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
        } },
        {},
        {}
    );
}

static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
    VK_LOG_DEBUG("ggml_vk_set_event()");

    ctx->s->buffer.setEvent(
        event,
        ctx->p->q->stage_flags
    );
}

static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
    VK_LOG_DEBUG("ggml_vk_wait_events()");
    if (events.empty()) {
        return;
    }

    ctx->s->buffer.waitEvents(
        events,
        ctx->p->q->stage_flags,
        ctx->p->q->stage_flags,
        {},
        {},
        {}
    );
}

// number of rows/cols for flash attention shader
static constexpr uint32_t flash_attention_num_small_rows = 32;
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;

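// Rows handled per workgroup on the scalar FA path. Note that (hsv | hsk) & 8 is nonzero
// exactly when either head size has bit 3 set, i.e. (for the multiple-of-8 head sizes
// used here) when either head size is not a multiple of 16.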
static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
    if (hsv >= 192) {
        return 2;
    } else if ((hsv | hsk) & 8 || small_cache) {
        return 4;
    } else {
        return 8;
    }
}

// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
// 128 threads split into four subgroups, each subgroup does 1/4
// of the Bc dimension.
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
static constexpr uint32_t scalar_flash_attention_Bc = 64;
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;

static uint32_t get_fa_num_small_rows(FaCodePath path) {
    if (path == FA_COOPMAT2) {
        return flash_attention_num_small_rows;
    } else {
        return scalar_flash_attention_num_small_rows;
    }
}

static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
    GGML_UNUSED(clamp);

    if (path == FA_SCALAR) {
        if (small_rows) {
            return {scalar_flash_attention_num_small_rows, 64};
        } else {
            if ((hsv | hsk) & 8) {
                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
            } else {
                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
            }
        }
    }

    if (path == FA_COOPMAT1) {
        if (small_rows) {
            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
        } else {
            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
        }
    }

    // small rows, large cols
    if (small_rows) {
        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
    }

    // small cols to reduce register count
    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
        if (hsk >= 512 || hsv >= 512) {
            return {32, 32};
        } else {
            return {64, 32};
        }
    }
    return {64, 64};
}
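// Worked example for the scalar path: hsk = hsv = 128 with small_rows == false and
// small_cache == false satisfies hsv < 192 and ((hsv | hsk) & 8) == 0, so fa_rows_cols
// returns {8, 32}: 8 rows per workgroup with Bc = 32.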

static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
}

static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {

    uint32_t lut_size = 0;
    switch (src0_type) {
    case GGML_TYPE_IQ1_S:
    case GGML_TYPE_IQ1_M:
        lut_size = 2*2048 + 4*2048;
        break;
    case GGML_TYPE_IQ2_XXS:
        lut_size = 8*256;
        break;
    case GGML_TYPE_IQ2_XS:
        lut_size = 8*512;
        break;
    case GGML_TYPE_IQ2_S:
        lut_size = 8*1024;
        break;
    case GGML_TYPE_IQ3_XXS:
        lut_size = 4*256;
        break;
    case GGML_TYPE_IQ3_S:
        lut_size = 4*512;
        break;
    case GGML_TYPE_IQ4_NL:
    case GGML_TYPE_IQ4_XS:
    case GGML_TYPE_MXFP4:
        lut_size = 4*16;
        break;
    default:
        break;
    }

    // Needs to be kept up to date on shader changes
    const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
    const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
    const uint32_t warps = warptile[0] / warptile[10];

    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
    const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;

    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;

    VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
                 "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);

    return supported;
}
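// Worked example (assuming a device without coopmat, fp16 shaders and subgroup size 32),
// using the small mmq warptile {32, 32, 32, 32, ...} from ggml_vk_load_shaders:
// warps = 32/32 = 1 and load_bufs = (32 + 32) * (32 + 1) * 2 = 4224 bytes; adding the
// IQ1_S/IQ1_M LUT (2*2048 + 4*2048 = 12288 bytes) totals 16512 bytes, well below a
// typical 32 KiB maxComputeSharedMemorySize.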

struct GpuPipelineConfig {
    // GPU architecture identifier.
    // Example: vk_device_architecture::AMD_GCN
    vk_device_architecture arch;

    // Mapping of pipeline names to their specific subgroup sizes.
    // Example: {"soft_max_f32", 64}
    std::unordered_map<std::string, uint32_t> pipelines;

    // Default subgroup size for this GPU.
    // Defaults to 0 if not explicitly provided.
    uint32_t default_subgroup_size = 0;
};

// Pipeline configuration for RDNA1 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
    {"soft_max", 64}, {"im2col", 64},
    {"argmax", 64}, {"mul_mat_vec", 64},
    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
};

// Pipeline configuration for RDNA2 GPUs.
static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
    {"soft_max", 64}, {"im2col", 64},
};

static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;

// Define configurations for different GPUs.
static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
    {
        vk_device_architecture::AMD_RDNA1,
        {
            rdna1_pipelines,
        },
        RDNA_DEFAULT_SUBGROUP_SIZE
    },
    {
        vk_device_architecture::AMD_RDNA2,
        {
            rdna2_pipelines,
        },
        RDNA_DEFAULT_SUBGROUP_SIZE
    },
};

static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
    for (const auto &config : gpu_pipeline_configs) {
        if (config.arch == arch) {
            auto pipIt = config.pipelines.find(pipeline_name);
            if (pipIt != config.pipelines.end()) {
                return pipIt->second;
            }
            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
            for (const auto &entry : sorted_pipelines) {
                if (pipeline_name.find(entry.first) != std::string::npos) {
                    return entry.second;
                }
            }
            return config.default_subgroup_size;
        }
    }
    return 0; // If no matching configuration is found
}
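// Example: on AMD_RDNA1, a hypothetical pipeline name "mul_mat_vec_f16_f32" misses the exact
// lookup, then matches the longest substring entry, "mul_mat_vec_f16", and returns 32.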

static void ggml_vk_load_shaders(vk_device& device) {
    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

    std::lock_guard<std::recursive_mutex> guard(device->mutex);
    // some shaders have a minimum subgroup size
    const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
    const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
    const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);

    const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
    const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
    const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
    const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);

    const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
                                      (device->subgroup_size_control && device->subgroup_max_size >= 16);

    // mulmat
    std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
                          l_warptile_id, m_warptile_id, s_warptile_id,
                          l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
                          l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
                          l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
                          l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
                          l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
                          l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
    std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
                            l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
                            l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
                            l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;

    uint32_t l_align, m_align, s_align;
    if (device->coopmat2) {
        // spec constants and tile sizes for non-quant matmul/matmul_id
        l_warptile = { 256, 128, 256, 64, 1 };
        m_warptile = { 256, 128, 128, 64, 0 };
        s_warptile = { 128,  64,  64, 64, 0 };
        l_wg_denoms = {128, 256, 1 };
        m_wg_denoms = {128, 128, 1 };
        s_wg_denoms = { 64,  64, 1 };

        // spec constants and tile sizes for quant matmul (non-Qi_K)
        l_warptile_mmq = { 256, 128, 256, 64, 1 };
        m_warptile_mmq = { 256, 128, 128, 64, 1 };
        s_warptile_mmq = { 256, 32,  64, 128, 0 };
        l_mmq_wg_denoms = { 128, 256, 1 };
        m_mmq_wg_denoms = { 128, 128, 1 };
        s_mmq_wg_denoms = { 32,  64,  1 };

        // spec constants and tile sizes for quant matmul (Qi_K)
        l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
        m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
        s_warptile_mmq_k = { 256, 32,  64, 128, 0 };
        l_mmq_wg_denoms_k = { 128, 256, 1 };
        m_mmq_wg_denoms_k = { 128, 128, 1 };
        s_mmq_wg_denoms_k = { 32,  64,  1 };

        // spec constants and tile sizes for quant matmul_id
        l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
        m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
        s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
        l_mmqid_wg_denoms = { 128, 128, 1 };
        m_mmqid_wg_denoms = { 128, 64, 1 };
        s_mmqid_wg_denoms = { 128, 64, 1 };

        l_align = 128;
        m_align =  64;
        s_align =  32;
    } else {
        // Matrix cores require different warp group sizes
        const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
        const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
        const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
        const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
        const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
        const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
        const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
        const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
        const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;

        const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;

        l_warptile = { 128,             128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
        m_warptile = { 128,              64,  64, 16, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
        s_warptile = { subgroup_size_32, 32,  32, 16, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
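        // Warptile element layout (inferred from these initializers and from their use in
        // ggml_vk_matmul_shmem_support): { workgroup size, tile M, tile N, tile K,
        // warp M, warp N, WMITER, tm, tn, tk, required subgroup size }.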

        l_warptile_mmq = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
        m_warptile_mmq = { 128,              64,  64, 32, subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
        s_warptile_mmq = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

        // Integer MMQ has a smaller shared memory profile, but heavier register use
        l_warptile_mmq_int = { 128,             128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
        m_warptile_mmq_int = { 128,              64,  64, 32, subgroup_size_8,     32, 2, 2, 2, 1, subgroup_size_8 };
        s_warptile_mmq_int = { subgroup_size_32, 32,  32, 32, s_warptile_wm,       32, 2, 2, 1, 1, subgroup_size_8 };

        // K-quants use even more registers, mitigate by setting WMITER to 1
        l_warptile_mmq_int_k = { 128,               128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
        m_warptile_mmq_int_k = { 128,                64,  64, 32, subgroup_size_8,     32, 1, 2, 2, 1, subgroup_size_8 };
        s_warptile_mmq_int_k = { subgroup_size_32,   32,  32, 32, s_warptile_wm,       32, 1, 2, 1, 1, subgroup_size_8 };

        l_warptile_id = { 128,                      128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
        m_warptile_id = { 128,                       64,  64, 16, mul_mat_subgroup_size_16,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
        s_warptile_id = { mul_mat_subgroup_size_16,  32,  32, 16, s_warptile_wm,                32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };

        l_warptile_mmqid = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
        m_warptile_mmqid = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
        s_warptile_mmqid = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };

        l_warptile_mmqid_int = { 128,                       128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
        m_warptile_mmqid_int = { 128,                        64,  64, 32, mul_mat_subgroup_size_8,     32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
        s_warptile_mmqid_int = { mul_mat_subgroup_size_32,   32,  32, 32, s_warptile_wm,               32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };

        l_warptile_mmqid_int_k = { 128,                     128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
        m_warptile_mmqid_int_k = { 128,                      64,  64, 32, mul_mat_subgroup_size_16,     32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
        s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32,  32, 32, s_warptile_wm,                32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };

        // chip-specific tuning
        if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
            m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
            m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
        } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
            // This intentionally uses the t{m,n,k}_m values; it gives a slight performance increase
            l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
            l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
            l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
        } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
            // Xe2/Xe3 with coopmat enabled - warptile performance tuning
            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
            l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
        }

        l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
        m_mmq_wg_denoms = m_wg_denoms = { 64,  64, 1 };
        s_mmq_wg_denoms = s_wg_denoms = { 32,  32, 1 };
        l_align = 128;
        m_align =  64;
        s_align =  32;

        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
            ggml_type t = (ggml_type)i;
            // Disable medium and large matrix multiplication if not enough shared memory is available.
            // Check the mmq warptiles, as they are the largest configuration.
            // Throw an error if there is not enough shared memory for any matrix multiplication.
            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
                std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
                throw std::runtime_error("Shared memory size too small for matrix multiplication.");
            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
                device->mul_mat_m[i] = false;
                device->mul_mat_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
                device->mul_mat_l[i] = false;
            }

            // Disable mul_mat_id if not enough shared memory is available
            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
                device->mul_mat_id_s[i] = false;
                device->mul_mat_id_m[i] = false;
                device->mul_mat_id_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
                device->mul_mat_id_m[i] = false;
                device->mul_mat_id_l[i] = false;
            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
                device->mul_mat_id_l[i] = false;
            }
        }
    }

    if (!device->pipeline_matmul_f32) {
        device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
    }
    if (!device->pipeline_matmul_f32_f16) {
        device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
    }
    if (!device->pipeline_matmul_id_f32) {
        device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
    }
    if (!device->pipeline_matmul_bf16) {
        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
    }
    if (!device->pipeline_matmul_id_bf16) {
        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
    }

    std::vector<std::future<void>> compiles;
    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {

        if (!require_full_subgroups && required_subgroup_size == 0) {
            required_subgroup_size = get_subgroup_size(name, device->architecture);
        }

        vk_pipeline *ptr = &base_pipeline;

        int num_pipelines = 1;
#if defined(VK_EXT_shader_64bit_indexing)
        if (device->shader_64b_indexing) {
            num_pipelines = 2;
        }
#endif
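        // Create the base pipeline and, when 64-bit indexing is supported, a second
        // variant chained through ->next with is_64b_indexing set on it.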
        for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
            vk_pipeline &pipeline = *ptr;
            if (!pipeline) {
                pipeline = std::make_shared<vk_pipeline_struct>();
            }
            if (!pipeline->initialized) {
                pipeline->name = name;
                pipeline->parameter_count = parameter_count;
                pipeline->push_constant_size = push_constant_size;
                pipeline->wg_denoms = wg_denoms;
                pipeline->align = align;
                pipeline->initialized = true;
#if defined(VK_EXT_shader_64bit_indexing)
                pipeline->is_64b_indexing = (i == 1);
#endif
            }

            if (!pipeline->needed || pipeline->compiled) {
                continue;
            }
            // TODO: We're no longer benefitting from the async compiles (shaders are
            // compiled individually, as needed) and this complexity can be removed.
            {
                // wait until fewer than N compiles are in progress
                uint32_t N = std::max(1u, std::thread::hardware_concurrency());
                std::unique_lock<std::mutex> guard(compile_count_mutex);
                while (compile_count >= N) {
                    compile_count_cond.wait(guard);
                }
                compile_count++;
            }

            compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
                                          parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
        }
    };

    auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
        return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
                                       parameter_count, push_constant_size, wg_denoms, specialization_constants,
                                       align, disable_robustness, require_full_subgroups, required_subgroup_size);
    };

    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
    };

    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
        // For a large number of rows, 128 invocations seem to work best.
        // For a small number of rows (e.g. N==1), 256 works better, but matrix granularity
        // for 256 is 32, so we can't use 256 for D==80.
        // For the scalar path, use 128 (arbitrary).
        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
        const uint32_t D = (hsk|hsv);
        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);

        uint32_t wg_size;
        switch (path) {
        case FA_COOPMAT2:
            wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
            break;
        case FA_COOPMAT1:
            wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
            break;
        default:
            wg_size = scalar_flash_attention_workgroup_size;
            break;
        }

        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
        const uint32_t D_lsb = D ^ (D & (D-1));
        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
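        // Example: hsk = hsv = 80 gives D = 80, whose lowest set bit is D_lsb = 16, so with
        // subgroup_size >= 8 this yields D_split = min(8, 16/4) = 4.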

        // Nvidia prefers using shared memory to load large tiles of K.
        // Switch to loading from global memory when it would use too much shared memory.
        // AMD prefers loading K directly from global memory.
        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;

        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
    };

#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
        for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
            uint32_t HSK = fa.first.HSK; \
            uint32_t HSV = fa.first.HSV; \
            bool small_rows = fa.first.small_rows; \
            bool small_cache = fa.first.small_cache; \
            FaCodePath path = fa.first.path; \
            bool aligned = fa.first.aligned; \
            bool f32acc = fa.first.f32acc; \
            uint32_t flags = fa.first.flags; \
            if (path == FAPATH) { \
                if (aligned) { \
                    if (f32acc) { \
                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                    } else { \
                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                    } \
                } else { \
                    if (f32acc) { \
                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                    } else { \
                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                    } \
                } \
            } \
        }
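// For example, CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) iterates every (HSK, HSV, ...)
// key registered in pipeline_flash_attn_f32_f16[GGML_TYPE_F16] and, for entries with
// path == FA_SCALAR, creates the matching aligned/unaligned x f16acc/f32acc pipeline.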

    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
    if (device->coopmat1_fa_support) {
        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
    }
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    if (device->coopmat2) {
        CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
    }
#endif
#undef CREATE_FA

    const int mul_mat_id_param_count = 5;

#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    if (device->coopmat2) {

        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true);   \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true);   \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true);   \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true);   \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true);   \
        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true);   \

        // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
        CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
        CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \

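        // Each CREATE_MM2 below expands into 12 pipeline creations per type:
        // {f16acc, f32acc} x {s, m, l} x {unaligned, aligned}.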
        CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
        if (device->coopmat_bf16_support) {
            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
        }
#endif
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S],   matmul_iq1_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M],   matmul_iq1_m_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S],   matmul_iq2_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S],   matmul_iq3_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4],   matmul_mxfp4_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)

        GGML_ASSERT(device->subgroup_ballot);

        CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
        if (device->coopmat_bf16_support) {
            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
        }
#endif
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f16,  mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
        CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f16,   mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
#undef CREATE_MM
#undef CREATE_MM2
    } else
#endif  // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
    if (device->coopmat_support) {
        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true);   \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true);   \

        // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
        if (device->coopmat_acc_f16_support) { \
            CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
        } \
        if (device->coopmat_acc_f32_support) { \
            CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
        } \

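        // Unlike the coopmat2 path above, these CREATE_MM instantiations are guarded by the
        // per-type mul_mat{,_id}_{s,m,l} flags computed from the shared-memory checks, so
        // tile sizes that don't fit are skipped.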
 3388        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 3389        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 3390        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 3391        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 3392#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 3393        if (device->coopmat_bf16_support) {
 3394            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
 3395        }
 3396#endif
 3397
 3398        if (device->coopmat_acc_f16_support) {
 3399            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3400            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3401            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3402            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3403            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3404
 3405            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3406            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3407            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3408            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3409            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3410            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3411            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3412            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3413            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3414            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3415            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3416            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3417            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3418            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3419            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3420        } else {
 3421            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3422            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3423            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3424            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3425            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3426
 3427            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3428            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3429            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3430            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3431            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3432            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3433            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3434            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3435            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3436            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3437            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3438            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3439            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3440            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3441            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 3442        }
 3443
 3444        GGML_ASSERT(device->subgroup_ballot);
 3445
        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
        if (device->coopmat_bf16_support) {
            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        }
#endif

        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
#undef CREATE_MM2
#undef CREATE_MM
    } else
#endif  // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
    if (device->fp16) {
        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \

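        // Illustration only (not generated code): with ID empty and REQSUBGROUPSIZE 0, a call such as
        //   CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        // expands its first guarded line to roughly:
        //   if (device->mul_mat_l[GGML_TYPE_F32])
        //       ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_f32_l",
        //                               matmul_f32_f32_len, matmul_f32_f32_data, "main", 3,
        //                               sizeof(vk_mat_mat_push_constants), l_wg_denoms, l_warptile, 1, false, false, 0);
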
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
        if (device->mul_mat ## ID ## _l[TYPE]) { \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC        "_l", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        } \
        if (device->mul_mat ## ID ## _m[TYPE]) { \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC        "_m", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        } \
        if (device->mul_mat ## ID ## _s[TYPE]) { \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC        "_s", NAMELC ## _len,        NAMELC ##  _data,        "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        } \

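        // CREATE_MMQ emits only the unaligned {l,m,s} variants, always into the f32-accumulator pipeline set.
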
        // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
        CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
        CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \

        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S],   matmul_iq1_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M],   matmul_iq1_m_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS],  matmul_iq2_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S],   matmul_iq2_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
        if (device->integer_dot_product) {
            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);

            CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);

            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
        }
#endif

        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
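            // Subgroup-optimized mul_mat_id variants: these need ballot ops, guaranteed full subgroups, and a subgroup size of at least 16.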
            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);

            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_subgroup_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_subgroup_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_subgroup_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_subgroup_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_subgroup_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_subgroup_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_subgroup_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_subgroup_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
            if (device->integer_dot_product) {
                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);

                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            }
#endif
        } else {
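            // Fallback mul_mat_id variants that do not require a specific subgroup size (REQSUBGROUPSIZE is 0 throughout).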
            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

            CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S],   matmul_id_iq1_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M],   matmul_id_iq1_m_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS],  matmul_id_iq2_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S],   matmul_id_iq2_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S],   matmul_id_iq3_s_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS],  matmul_id_iq4_xs_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL],  matmul_id_iq4_nl_f32,  mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4],   matmul_id_mxfp4_f32,   mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
            if (device->integer_dot_product) {
                CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

                CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int,   vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

                CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
                CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            }
#endif
        }
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
    } else {
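        // Without fp16 support, all matmul shaders fall back to the fp32-storage SPIR-V variants (the *_fp32_len/_fp32_data arrays below).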
        // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE);   \

#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
        if (device->mul_mat ## ID ## _l[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
        if (device->mul_mat ## ID ## _m[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1);   \
        if (device->mul_mat ## ID ## _s[TYPE]) \
            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1);   \

        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

        CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc,   matmul_iq1_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc,   matmul_iq1_m_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc,  matmul_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc,   matmul_iq2_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);

#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
        if (device->integer_dot_product) {
            CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );

            CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
            CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
        }
#endif

        if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);

            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_subgroup_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_subgroup_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_subgroup_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_subgroup_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_subgroup_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_subgroup_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_subgroup_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_subgroup_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
        } else {
            CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);

            CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ1_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc,   matmul_id_iq1_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ1_M,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc,   matmul_id_iq1_m_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ2_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc,  matmul_id_iq2_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ2_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc,   matmul_id_iq2_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
        }
    }
    // Reusing CREATE_MM from the fp32 path: macro (re)definitions resolve at preprocess time, so the
    // textually last definition above (the fp32 one) is in effect here regardless of the runtime branch taken.
    if ((device->coopmat2 || device->coopmat_support)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
        && !device->coopmat_bf16_support
#endif
        ) {
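        // Cooperative-matrix devices without bf16 coopmat support skipped bf16 above; create scalar bf16 pipelines here instead.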
        // use scalar tile sizes
        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
        m_warptile = { 128,  64,  64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };

        l_wg_denoms = {128, 128, 1 };
        m_wg_denoms = { 64,  64, 1 };
        s_wg_denoms = { 32,  32, 1 };

        if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
            // Xe2/Xe3 - bf16 warptile performance tuning
            l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
        }

        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
    }
#undef CREATE_MM

    // mul mat vec

    // the number of rows computed per shader depends on GPU model and quant
    uint32_t rm_stdq = 1;
    uint32_t rm_kq = 2;
    uint32_t rm_stdq_int = 1;
    uint32_t rm_kq_int = 1;
    const auto rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
    if (device->vendor_id == VK_VENDOR_ID_AMD) {
        if (device->architecture == AMD_GCN) {
            rm_stdq = 2;
            rm_kq = 4;
            rm_stdq_int = 4;
        }
    } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
        rm_stdq = 2;
        rm_stdq_int = 2;
    }
    uint32_t rm_iq = 2 * rm_kq;
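    // rm_stdq applies to the legacy quants, rm_kq to the k-quants, rm_iq to the iq quants; the *_int
    // counterparts apply to the q8_1 integer-dot shaders. AMD GCN and Intel get larger row counts.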

    const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
    // Ensure a subgroup size >= 16 is available
    const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;

    const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
    const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);

    const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
    const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
    static constexpr uint32_t mul_mat_vec_num_bindings = 5;
    static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;

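    // Each quant gets two workgroup-size flavors: one sized to a single subgroup and one four times larger.
    // The reduction strategy follows from that choice: full subgroup ops, a hybrid, or plain shared memory.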
    for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
        const uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
        const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);

        const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
                                            (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
                                            SHADER_REDUCTION_MODE_SHMEM;

        const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
                                              (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
                                              SHADER_REDUCTION_MODE_SHMEM;

        for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32",  arr_dmmv_f32_f32_f32_len[reduc],  arr_dmmv_f32_f32_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32",  arr_dmmv_f16_f32_f32_len[reduc],  arr_dmmv_f16_f32_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i],   "mul_mat_vec_iq1_s_f32_f32",   arr_dmmv_iq1_s_f32_f32_len[reduc16],   arr_dmmv_iq1_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i],   "mul_mat_vec_iq1_m_f32_f32",   arr_dmmv_iq1_m_f32_f32_len[reduc16],   arr_dmmv_iq1_m_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i],  "mul_mat_vec_iq2_xs_f32_f32",  arr_dmmv_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_iq2_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3860            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f32_f32",   arr_dmmv_iq2_s_f32_f32_len[reduc16],   arr_dmmv_iq2_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3861            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3862            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f32_f32",   arr_dmmv_iq3_s_f32_f32_len[reduc16],   arr_dmmv_iq3_s_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3863            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f32_f32",  arr_dmmv_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_iq4_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3864            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f32_f32",  arr_dmmv_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_iq4_nl_f32_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3865            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f32_f32",   arr_dmmv_mxfp4_f32_f32_len[reduc16],   arr_dmmv_mxfp4_f32_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3866
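            // the same src0 types, with f16 src1 data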
 3867            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32",  arr_dmmv_f32_f16_f32_len[reduc],  arr_dmmv_f32_f16_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
 3868            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32",  arr_dmmv_f16_f16_f32_len[reduc],  arr_dmmv_f16_f16_f32_data[reduc],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
 3869            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
 3870            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
 3871            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
 3872            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
 3873            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
 3874            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
 3875            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3876            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3877            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3878            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3879            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3880            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i],   "mul_mat_vec_iq1_s_f16_f32",   arr_dmmv_iq1_s_f16_f32_len[reduc16],   arr_dmmv_iq1_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3881            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i],   "mul_mat_vec_iq1_m_f16_f32",   arr_dmmv_iq1_m_f16_f32_len[reduc16],   arr_dmmv_iq1_m_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3882            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3883            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i],  "mul_mat_vec_iq2_xs_f16_f32",  arr_dmmv_iq2_xs_f16_f32_len[reduc16],  arr_dmmv_iq2_xs_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3884            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i],   "mul_mat_vec_iq2_s_f16_f32",   arr_dmmv_iq2_s_f16_f32_len[reduc16],   arr_dmmv_iq2_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3885            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3886            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f16_f32",   arr_dmmv_iq3_s_f16_f32_len[reduc16],   arr_dmmv_iq3_s_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3887            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f16_f32",  arr_dmmv_iq4_xs_f16_f32_len[reduc16],  arr_dmmv_iq4_xs_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3888            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f16_f32",  arr_dmmv_iq4_nl_f16_f32_len[reduc16],  arr_dmmv_iq4_nl_f16_f32_data[reduc16],  "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3889            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f16_f32",   arr_dmmv_mxfp4_f16_f32_len[reduc16],   arr_dmmv_mxfp4_f16_f32_data[reduc16],   "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
 3890
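            // variants that consume src1 pre-quantized to q8_1, using integer dot product instructions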
 3891#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 3892            if (device->integer_dot_product) {
 3893                const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
 3894                const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
 3895
 3896                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3897                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3898                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3899                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3900                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3901
 3902                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3903
 3904                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3905                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3906                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3907                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3908                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
 3909
 3910                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
 3911                ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
 3913            }
 3914#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
 3915        }
 3916
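        // indirect (MUL_MAT_ID) variants of the pipelines above; these are built for a single src1 column only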
 3917        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",        arr_dmmv_id_f32_f32_f32_len[reduc],     arr_dmmv_id_f32_f32_f32_data[reduc],     "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
 3918        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32",        arr_dmmv_id_f16_f32_f32_len[reduc],     arr_dmmv_id_f16_f32_f32_data[reduc],     "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
 3919        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32",       arr_dmmv_id_bf16_f32_f32_len[reduc],    arr_dmmv_id_bf16_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
 3920        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32",       arr_dmmv_id_q4_0_f32_f32_len[reduc],    arr_dmmv_id_q4_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
 3921        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32",       arr_dmmv_id_q4_1_f32_f32_len[reduc],    arr_dmmv_id_q4_1_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
 3922        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32",       arr_dmmv_id_q5_0_f32_f32_len[reduc],    arr_dmmv_id_q5_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
 3923        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32",       arr_dmmv_id_q5_1_f32_f32_len[reduc],    arr_dmmv_id_q5_1_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
 3924        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32",       arr_dmmv_id_q8_0_f32_f32_len[reduc],    arr_dmmv_id_q8_0_f32_f32_data[reduc],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
 3925        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32",       arr_dmmv_id_q2_k_f32_f32_len[reduc16],    arr_dmmv_id_q2_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
 3926        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32",       arr_dmmv_id_q3_k_f32_f32_len[reduc16],    arr_dmmv_id_q3_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
 3927        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32",       arr_dmmv_id_q4_k_f32_f32_len[reduc16],    arr_dmmv_id_q4_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
 3928        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32",       arr_dmmv_id_q5_k_f32_f32_len[reduc16],    arr_dmmv_id_q5_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
 3929        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32",       arr_dmmv_id_q6_k_f32_f32_len[reduc16],    arr_dmmv_id_q6_k_f32_f32_data[reduc16],    "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
 3930        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S],   "mul_mat_vec_id_iq1_s_f32",   arr_dmmv_id_iq1_s_f32_f32_len[reduc16],   arr_dmmv_id_iq1_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3931        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M],   "mul_mat_vec_id_iq1_m_f32",   arr_dmmv_id_iq1_m_f32_f32_len[reduc16],   arr_dmmv_id_iq1_m_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3932        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3933        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS],  "mul_mat_vec_id_iq2_xs_f32",  arr_dmmv_id_iq2_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq2_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3934        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S],   "mul_mat_vec_id_iq2_s_f32",   arr_dmmv_id_iq2_s_f32_f32_len[reduc16],   arr_dmmv_id_iq2_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3935        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3936        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S],   "mul_mat_vec_id_iq3_s_f32",   arr_dmmv_id_iq3_s_f32_f32_len[reduc16],   arr_dmmv_id_iq3_s_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3937        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS],  "mul_mat_vec_id_iq4_xs_f32",  arr_dmmv_id_iq4_xs_f32_f32_len[reduc16],  arr_dmmv_id_iq4_xs_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3938        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL],  "mul_mat_vec_id_iq4_nl_f32",  arr_dmmv_id_iq4_nl_f32_f32_len[reduc16],  arr_dmmv_id_iq4_nl_f32_f32_data[reduc16],  "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3939        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4],   "mul_mat_vec_id_mxfp4_f32",   arr_dmmv_id_mxfp4_f32_f32_len[reduc16],   arr_dmmv_id_mxfp4_f32_f32_data[reduc16],   "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
 3940
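        // integer dot product (q8_1 src1) variants of the MUL_MAT_ID pipelines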
 3941#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 3942        if (device->integer_dot_product) {
 3943            const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
 3944            const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
 3945
 3946            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3947            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3948            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3949            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3950            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3951
 3952            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
 3953
 3954            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 3955            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 3956            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 3957            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 3958            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
 3959
 3960            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
 3961            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
 3962        }
 3963#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
 3964    }
 3965
 3966#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 3967    GGML_UNUSED(rm_stdq_int);
 3968    GGML_UNUSED(rm_kq_int);
 3969    GGML_UNUSED(rm_iq_int);
 3970#endif
 3971
 3972    // dequant shaders
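    // note: the GGML_TYPE_F32 entry ("f32_to_f16") converts f32 to f16 rather than dequantizing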
 3973    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3974    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3975    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3976    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3977    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3978    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3979    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 3980    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 3981    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3982    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 3983    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
 3984    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S],   "dequant_iq1_s",   dequant_iq1_s_len,   dequant_iq1_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3985    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M],   "dequant_iq1_m",   dequant_iq1_m_len,   dequant_iq1_m_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3986    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3987    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS],  "dequant_iq2_xs",  dequant_iq2_xs_len,  dequant_iq2_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3988    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S],   "dequant_iq2_s",   dequant_iq2_s_len,   dequant_iq2_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3989    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3990    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S],   "dequant_iq3_s",   dequant_iq3_s_len,   dequant_iq3_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3991    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS],  "dequant_iq4_xs",  dequant_iq4_xs_len,  dequant_iq4_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
 3992    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL],  "dequant_iq4_nl",  dequant_iq4_nl_len,  dequant_iq4_nl_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3993    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4],   "dequant_mxfp4",   dequant_mxfp4_len,   dequant_mxfp4_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 3994
 3995    // get_rows
 3996    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 3997    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16",  get_rows_f16_len,  get_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 3998    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 3999    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4000    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4001    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4002    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4003    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4004    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4005    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4006    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4007    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4008    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4009    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S],   "get_rows_iq1_s",   get_rows_iq1_s_len,   get_rows_iq1_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4010    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M],   "get_rows_iq1_m",   get_rows_iq1_m_len,   get_rows_iq1_m_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4011    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4012    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS],  "get_rows_iq2_xs",  get_rows_iq2_xs_len,  get_rows_iq2_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4013    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S],   "get_rows_iq2_s",   get_rows_iq2_s_len,   get_rows_iq2_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4014    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4015    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S],   "get_rows_iq3_s",   get_rows_iq3_s_len,   get_rows_iq3_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4016    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs",  get_rows_iq4_xs_len,  get_rows_iq4_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4017    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl",  get_rows_iq4_nl_len,  get_rows_iq4_nl_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4018    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4],   "get_rows_mxfp4",   get_rows_mxfp4_len,   get_rows_mxfp4_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4019    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32],     "get_rows_i32",     get_rows_i32_len,     get_rows_i32_data,     "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4020
 4021    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 4022    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 4023    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
 4024    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4025    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4026    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4027    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4028    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4029    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4030    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4031    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4032    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4033    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4034    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S],   "get_rows_iq1_s_f32",   get_rows_iq1_s_f32_len,   get_rows_iq1_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4035    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M],   "get_rows_iq1_m_f32",   get_rows_iq1_m_f32_len,   get_rows_iq1_m_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4036    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4037    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS],  "get_rows_iq2_xs_f32",  get_rows_iq2_xs_f32_len,  get_rows_iq2_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4038    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S],   "get_rows_iq2_s_f32",   get_rows_iq2_s_f32_len,   get_rows_iq2_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4039    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4040    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S],   "get_rows_iq3_s_f32",   get_rows_iq3_s_f32_len,   get_rows_iq3_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4041    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs_f32",  get_rows_iq4_xs_f32_len,  get_rows_iq4_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4042    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4043    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 4044
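    // reductions that combine the partial results of split-k matmul and split-k flash attention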
 4045    ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 4046    ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
 4047
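    // flash attention mask optimization, specialized per (Br, Bc) block-size pair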
 4048    for (auto &it : device->pipeline_fa_mask_opt) {
 4049        auto BrBc = it.first;
 4050        ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
 4051    }
 4052
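    // quantization of src1 to q8_1 (x4 block layout); prefer the subgroup-clustered shader when fully supported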
 4053    if (device->subgroup_clustered && device->subgroup_require_full_support) {
 4054        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
 4055    } else {
 4056        ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 4057    }
 4058
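    // f16 mat-vec for the p021-permuted layout, one pipeline per gqa ratio; uses
    // subgroup arithmetic for the reduction when the device supports it reliably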
 4059    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
 4060        if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
 4061            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
 4062        } else {
 4063            ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len,              mul_mat_vec_p021_f16_f32_data,              "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
 4064        }
 4065    }
 4066    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
 4067
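    // normalization pipelines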
 4068    ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 4069    ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 4070
 4071    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
 4072    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 4073    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
 4074    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 4075
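    // fused rms_norm(+mul)+rope, gated on RTE fp16 support and the push constant size limit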
 4076    if (device->float_controls_rte_fp16 &&
 4077        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
 4078        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 4079        ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
 4080    }
 4081
 4082    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 4083    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 4084
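    // copy/cast pipelines for (possibly) non-contiguous tensors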
 4085    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4086    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4087    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4088    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16, "cpy_f32_bf16", cpy_f32_bf16_len, cpy_f32_bf16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4090    ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4091    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4092
 4093    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4094    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4095    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4096    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4097    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4098    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4099    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 4100
 4101    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 4102    ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 4103
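    // Copies that quantize f32 data: use the _rte shader variants (explicit
    // round-to-nearest-even) when the device supports fp16 RTE float controls.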
    if (device->float_controls_rte_fp16) {
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
    }

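// Token-pasting helper: instantiates one set_rows pipeline per destination type
// for the given index type (_i32 or _i64) and rounding variant (_rte or empty).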
#define SET_ROWS(itype, rte) \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32],  "set_rows_f32" #itype,  set_rows_f32 ## itype ## rte ## _len,  set_rows_f32 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16],  "set_rows_f16" #itype,  set_rows_f16 ## itype ## rte ## _len,  set_rows_f16 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);

    if (device->float_controls_rte_fp16) {
        SET_ROWS(_i32, _rte)
        SET_ROWS(_i64, _rte)
    } else {
        SET_ROWS(_i32, )
        SET_ROWS(_i64, )
    }
#undef SET_ROWS

    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);

    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
        std::string s;
        s += std::string(src0_f16 ? "_f16" : "_f32");
        s += std::string(src1_f16 ? "_f16" : "_f32");
        s += std::string(dst_f16 ? "_f16" : "_f32");
        return s;
    };

    bool rte = device->float_controls_rte_fp16;
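    // Binary ops are instantiated for every f32/f16 combination of src0, src1
    // and dst; the _norepeat variants are specialized (spec constant 1) for the
    // case where src1 does not need to be broadcast.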
#define CREATE_BINARY(name, namemod, spec, bindings) \
    for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
        ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
                                "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);

    CREATE_BINARY(add, , {0}, 4)
    CREATE_BINARY(add, _norepeat, {1}, 4)
    CREATE_BINARY(sub, , {0}, 3)
    CREATE_BINARY(sub, _norepeat, {1}, 3)
    CREATE_BINARY(mul, , {0}, 3)
    CREATE_BINARY(mul, _norepeat, {1}, 3)
    CREATE_BINARY(div, , {0}, 3)
    CREATE_BINARY(div, _norepeat, {1}, 3)
    CREATE_BINARY(add_rms, , {0}, 4)
    CREATE_BINARY(add_rms, _norepeat, {1}, 4)
#undef CREATE_BINARY

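    // Fused chains of adds; the spec constant (i + 2) appears to select the
    // number of operands consumed by a single dispatch.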
    if (device->multi_add) {
        for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
            ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i],     "multi_add_f32_"     + std::to_string(i+1), multi_add_f32_len,     multi_add_f32_data,     "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
            ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
        }
    }

    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_nearest_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_bilinear_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_bicubic_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_bilinear_antialias_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    if (device->float_controls_rte_fp16) {
        ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    }

    ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

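// Elementwise unary activations, one pipeline each for f32 and f16 data.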
#define CREATE_UNARY(name)  \
    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);  \
    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    CREATE_UNARY(gelu)
    CREATE_UNARY(gelu_erf)
    CREATE_UNARY(gelu_quick)
    CREATE_UNARY(silu)
    CREATE_UNARY(relu)
    CREATE_UNARY(xielu)
    CREATE_UNARY(neg)
    CREATE_UNARY(tanh)
    CREATE_UNARY(sigmoid)
    CREATE_UNARY(hardsigmoid)
    CREATE_UNARY(hardswish)
    CREATE_UNARY(abs)
    CREATE_UNARY(softplus)
    CREATE_UNARY(step)
    CREATE_UNARY(round)
    CREATE_UNARY(ceil)
    CREATE_UNARY(floor)
    CREATE_UNARY(trunc)
#undef CREATE_UNARY

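// Same as CREATE_UNARY, but selects the _rte shader variants when fp16
// round-to-nearest-even float controls are available.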
#define CREATE_UNARY_RTE(name)  \
    if (device->float_controls_rte_fp16) {  \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
    } else {    \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);   \
    }
    CREATE_UNARY_RTE(exp)
#undef CREATE_UNARY_RTE

    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

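// GLU-family gated activations; as above, prefer the _rte shader variants when
// fp16 RTE rounding is supported.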
#define CREATE_GLU(name)  \
    if (device->float_controls_rte_fp16) {  \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
    } else {    \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
    }

    CREATE_GLU(geglu)
    CREATE_GLU(reglu)
    CREATE_GLU(swiglu)
    CREATE_GLU(swiglu_oai)
    CREATE_GLU(geglu_erf)
    CREATE_GLU(geglu_quick)
#undef CREATE_GLU

    ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);

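    // Multi-pass soft_max for rows too large for the single-workgroup kernel;
    // three passes, each with a 128-invocation workgroup and an unroll factor
    // of 4 (the {128, 4} spec constants).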
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32,     "soft_max_large1_f32",     soft_max_large1_f32_len,     soft_max_large1_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32,     "soft_max_large2_f32",     soft_max_large2_f32_len,     soft_max_large2_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32,     "soft_max_large3_f32",     soft_max_large3_f32_len,     soft_max_large3_f32_data,     "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);

    ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

    if (device->float_controls_rte_fp16) {
        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);

        ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
    }

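    // argsort: one pipeline per power-of-two padded column count. The
    // single-pass kernel keeps the whole row in shared memory (2 ints per
    // element); the _large variant handles rows that exceed that limit.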
    for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
        uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
        if (i <= device->max_workgroup_size_log2 &&
            2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
            const uint32_t NCOLS_PADDED_LOG2 = i;
            ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
        }
        const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
        BLOCK_SIZE /= WG_UNROLL_FACTOR;
        ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
    }

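    // top-k: prefer the subgroup-based n-ary search kernel when the required
    // subgroup features and shared memory are available, otherwise fall back
    // to an argsort-based kernel.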
    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
        const uint32_t BLOCK_SIZE = 1u << i;
        const uint32_t NCOLS_PADDED_LOG2 = i;
        if (i <= device->max_workgroup_size_log2) {
            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
                                  sizeof(int) * device->subgroup_size +
                                  2 * sizeof(int) +
                                  2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
            }
        }
    }

    ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

    ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

    const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32,       "cumsum_f32",       cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_small_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
    ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);

    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);

    ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);

    for (auto &s : device->pipeline_solve_tri_f32) {
        const vk_solve_tri_pipeline_state &state = s.first;

        // Max number of rows to load at a time, limited by shared memory
        const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
        // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
        const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
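        // e.g. with 32 KiB of shared memory and N = K = 64:
        // batch_N = 32768 / ((64 + 64) * 4) = 64 rows per pass, block_size = 128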

        ggml_vk_create_pipeline(
            device, s.second, "solve_tri_f32",
            solve_tri_f32_len, solve_tri_f32_data, "main", 3,
            sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
    }

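// im2col comes in two flavors: _bda indexes through buffer device addresses
// (requires 64-bit integer support in shaders), the other through descriptors.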
#define IM2COL(bda) \
    ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
    ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
    if (device->float_controls_rte_fp16) {  \
        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
    } else {    \
        ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);   \
        ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);      \
    }
    if (device->shader_int64 && device->buffer_device_address) {
        IM2COL(_bda)
    } else {
        IM2COL()
    }

    ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

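    // SSM scan (state-space model recurrence): prefer the subgroup-arithmetic
    // kernel when full subgroup support is guaranteed.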
    if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
    } else {
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
        ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
    }

    ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

    // conv2d, conv_transpose_2d
    for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
        uint32_t conv2d_WG_SIZE  = 256;
        uint32_t use_collectives = 0;  // Use subgroup ops to avoid recomputing indices.
        uint32_t conv2d_TS_K     = (s == CONV_SHAPE_64x32) ? 4 : 8;
        uint32_t conv2d_SHMEM_PAD = 4;
        vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s];
        bool conv2d_UNROLL = true;

#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
        if (device->coopmat2) {
            conv2d_SHMEM_PAD = 8; // 8 float16_t
        }
#endif

        if (device->vendor_id == VK_VENDOR_ID_INTEL) {
            conv2d_SHMEM_PAD = 0;
            conv2d_UNROLL = false;
        } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
            conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
            if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) {
                conv2d_UNROLL = false;
            }
        }

        // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
        bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
                                    device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
        bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
                                     device->architecture == vk_device_architecture::AMD_GCN;

        if (device->subgroup_shuffle &&
            device->vendor_id != VK_VENDOR_ID_INTEL &&   // Do not enable collectives on Intel, see PR 14316.
            allow_collectives_nv &&
            allow_collectives_amd) {
            use_collectives = 1;
            conv2d_BS.CRS   = std::min(
                device->subgroup_size,
                conv2d_BS.CRS);  // CRS block size should be capped at subgroup size for correctness when shuffle is used.
        }

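        // Shared memory holds a K x CRS kernel tile and a CRS x NPQ input tile,
        // each row padded by SHMEM_PAD floats; shrink CRS if this does not fit.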
        uint32_t conv2d_shmem_req =
            (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
        if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
            conv2d_BS.CRS = 8;
            if (use_collectives) {
                conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
            }
        }

        std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
        std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };

#define CREATE_CONV(name, type_suffix, spv_suffix) \
        for (auto &c : device->pipeline_##name##type_suffix[s]) { \
            const vk_conv2d_pipeline_state &state = c.first;  \
            std::vector<uint32_t> spec_constants_cpy = spec_constants; \
            spec_constants_cpy.push_back(state.s0); \
            spec_constants_cpy.push_back(state.s1); \
            spec_constants_cpy.push_back(state.p0); \
            spec_constants_cpy.push_back(state.p1); \
            spec_constants_cpy.push_back(state.d0); \
            spec_constants_cpy.push_back(state.d1); \
            spec_constants_cpy.push_back(state.KW); \
            spec_constants_cpy.push_back(state.KH); \
            ggml_vk_create_pipeline( \
                device, c.second, #name #type_suffix, \
                name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives);    \
        }
#define CREATE_CONVS(spv_suffix) \
        CREATE_CONV(conv2d, _f32, spv_suffix) \
        CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
        CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
        CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix)
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
        if (device->coopmat2) {
            CREATE_CONVS(_cm2)
        } else
#endif
        if (conv2d_UNROLL) {
            CREATE_CONVS(_unroll)
        } else {
            CREATE_CONVS( )
        }
#undef CREATE_CONV
#undef CREATE_CONVS
    }

    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);

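    // topk_moe: fused top-k + softmax used for mixture-of-experts routing; one
    // pipeline per power-of-two expert count, each in two variants selected by
    // the use_push spec constant.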
    for (uint32_t use_push = 0; use_push < 2; ++use_push) {
        for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
        }
    }

    for (auto &c : compiles) {
        c.wait();
    }
}

static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);

static vk_device ggml_vk_get_device(size_t idx) {
    VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");

    if (vk_instance.devices[idx] == nullptr) {
        VK_LOG_DEBUG("Initializing new vk_device");
        vk_device device = std::make_shared<vk_device_struct>();
        vk_instance.devices[idx] = device;

        device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());

        size_t dev_num = vk_instance.device_indices[idx];

        std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();

        if (dev_num >= physical_devices.size()) {
            std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
            throw std::runtime_error("Device not found");
        }

        device->physical_device = physical_devices[dev_num];
        const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();

        device->architecture = get_device_architecture(device->physical_device);

        const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
        device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;

        const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
        device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;

        const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
        device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;

        const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
        device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;

        bool fp16_storage = false;
        bool fp16_compute = false;
        bool maintenance4_support = false;
        bool sm_builtins = false;
        bool amd_shader_core_properties2 = false;
        bool pipeline_robustness = false;
        bool coopmat2_support = false;
        bool pipeline_executable_properties_support = false;
        device->coopmat_support = false;
        device->integer_dot_product = false;
        device->shader_64b_indexing = false;
        bool bfloat16_support = false;

        for (const auto& properties : ext_props) {
            if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                maintenance4_support = true;
            } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
                fp16_storage = true;
            } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
                fp16_compute = true;
            } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
                sm_builtins = true;
            } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
                amd_shader_core_properties2 = true;
            } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
                pipeline_robustness = true;
            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                device->subgroup_size_control = true;
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_COOPMAT")) {
                device->coopmat_support = true;
                device->coopmat_m = 0;
                device->coopmat_n = 0;
                device->coopmat_k = 0;
#endif
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
            } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_COOPMAT2")) {
                coopmat2_support = true;
#endif
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
                device->integer_dot_product = true;
#endif
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
            } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
                       !getenv("GGML_VK_DISABLE_BFLOAT16")) {
                bfloat16_support = true;
#endif
            } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
                pipeline_executable_properties_support = true;
            } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
                       getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
                device->memory_priority = true;
            } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
                device->external_memory_host = true;
#if defined(VK_EXT_shader_64bit_indexing)
            } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
                device->shader_64b_indexing = true;
#endif
            }
        }

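        // Query device properties through a pNext chain; optional property
        // structs are appended only when the matching extension was found above.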
        vk::PhysicalDeviceProperties2 props2;
        vk::PhysicalDeviceMaintenance3Properties props3;
        vk::PhysicalDeviceMaintenance4Properties props4;
        vk::PhysicalDeviceSubgroupProperties subgroup_props;
        vk::PhysicalDeviceDriverProperties driver_props;
        vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
        vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
        vk::PhysicalDeviceVulkan11Properties vk11_props;
        vk::PhysicalDeviceVulkan12Properties vk12_props;
        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
        vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;

        props2.pNext = &props3;
        props3.pNext = &subgroup_props;
        subgroup_props.pNext = &driver_props;
        driver_props.pNext = &vk11_props;
        vk11_props.pNext = &vk12_props;

        VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;

        if (maintenance4_support) {
            last_struct->pNext = (VkBaseOutStructure *)&props4;
            last_struct = (VkBaseOutStructure *)&props4;
        }
        if (sm_builtins) {
            last_struct->pNext = (VkBaseOutStructure *)&sm_props;
            last_struct = (VkBaseOutStructure *)&sm_props;
        }
        if (amd_shader_core_properties2) {
            last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
            last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
        }
        if (device->subgroup_size_control) {
            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
            last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
        }

#if defined(VK_NV_cooperative_matrix2)
        vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
        if (coopmat2_support) {
            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
            last_struct = (VkBaseOutStructure *)&coopmat2_props;
        }
#endif

        if (device->integer_dot_product) {
            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
        }

        if (device->external_memory_host) {
            last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
            last_struct = (VkBaseOutStructure *)&external_memory_host_props;
        }

        device->physical_device.getProperties2(&props2);
        device->properties = props2.properties;
        device->vendor_id = device->properties.vendorID;
        device->driver_id = driver_props.driverID;

        if (device->driver_id == vk::DriverId::eMoltenvk) {
            // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
            // is available in the Vulkan SDK.
            device->external_memory_host = false;
        }

        // The async backend interfaces appear to be broken on older Intel hardware,
        // see https://github.com/ggml-org/llama.cpp/issues/17302.
        device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
                                 std::string(device->properties.deviceName.data()).find("(DG1)") == std::string::npos) &&
                                getenv("GGML_VK_DISABLE_ASYNC") == nullptr;

        if (!device->support_async) {
            GGML_LOG_DEBUG("ggml_vulkan: WARNING: Async execution disabled (Intel DG1 or GGML_VK_DISABLE_ASYNC set).\n");
        }
 4709
 4710        const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
 4711
 4712        if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
 4713            device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
 4714        } else if (maintenance4_support) {
 4715            device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
 4716        } else {
 4717            device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
 4718        }
 4719
 4720        const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
 4721
 4722        if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
 4723            device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
 4724        } else if (maintenance4_support) {
 4725            device->max_buffer_size = props4.maxBufferSize;
 4726        } else {
 4727            device->max_buffer_size = device->max_memory_allocation_size;
 4728        }
 4729
 4730        const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
 4731
 4732        if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
 4733            device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
 4734        } else {
 4735            // Limit batching of allocations to 1GB by default to avoid fragmentation issues
 4736            device->suballocation_block_size = 1024*1024*1024;
 4737        }
 4738        device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 4739
 4740        device->subgroup_size = subgroup_props.subgroupSize;
 4741        device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
 4742        device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 4743        if (sm_builtins) {
 4744            device->shader_core_count = sm_props.shaderSMCount;
 4745        } else if (amd_shader_core_properties2) {
 4746            device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
 4747        } else {
 4748            device->shader_core_count = 0;
 4749        }
 4750        device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 4751
 4752        device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4753                                 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
 4754        device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4755                                      (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 4756#ifdef __APPLE__
 4757        // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)
 4758        if (device->vendor_id == VK_VENDOR_ID_AMD) {
 4759            device->subgroup_arithmetic = false;
 4760        }
 4761#endif
 4762        device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4763                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
 4764        device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4765                                     (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
 4766
 4767        device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4768                                  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
 4769
 4770        device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 4771                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
 4772
 4773        const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 4774
 4775        device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 4776
 4777        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
 4778            device->coopmat_support = false;
 4779        }
 4780
 4781        device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
 4782
 4783        device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
 4784
 4785        device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
 4786
 4787        std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 4788
 4789        // Try to find a non-graphics compute queue and transfer-focused queues
 4790        const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
 4791        const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
 4792
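        // Queue setup: a dedicated transfer queue when a separate family exists,
        // otherwise a second queue in the compute family if available, otherwise
        // a single queue shared for compute and transfer.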
 4793        const float priorities[] = { 1.0f, 1.0f };
 4794        device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
 4795
 4796        std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
 4797        if (compute_queue_family_index != transfer_queue_family_index) {
 4798            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
 4799            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
 4800        } else if(!device->single_queue) {
 4801            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
 4802        } else {
 4803            device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
 4804        }
 4805        vk::DeviceCreateInfo device_create_info;
 4806        std::vector<const char *> device_extensions;
 4807        vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
 4808
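        // Build a pNext chain of feature structs, appending each optional struct
        // only when its extension was detected above. The whole chain is then
        // filled with a single vkGetPhysicalDeviceFeatures2 call.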
 4809        VkPhysicalDeviceFeatures2 device_features2;
 4810        device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
 4811        device_features2.pNext = nullptr;
 4812        device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 4813
 4814        VkPhysicalDeviceVulkan11Features vk11_features;
 4815        vk11_features.pNext = nullptr;
 4816        vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
 4817        device_features2.pNext = &vk11_features;
 4818
 4819        VkPhysicalDeviceVulkan12Features vk12_features;
 4820        vk12_features.pNext = nullptr;
 4821        vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
 4822        vk11_features.pNext = &vk12_features;
 4823
 4824        last_struct = (VkBaseOutStructure *)&vk12_features;
 4825
 4826        VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
 4827        pl_robustness_features.pNext = nullptr;
 4828        pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
 4829        pl_robustness_features.pipelineRobustness = VK_FALSE;
 4830
 4831        if (pipeline_robustness) {
 4832            last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
 4833            last_struct = (VkBaseOutStructure *)&pl_robustness_features;
 4834            device_extensions.push_back("VK_EXT_pipeline_robustness");
 4835        }
 4836
 4837        VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;
 4838        memory_priority_features.pNext = nullptr;
 4839        memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;
 4840        memory_priority_features.memoryPriority = VK_FALSE;
 4841        if (device->memory_priority) {
 4842            last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;
 4843            last_struct = (VkBaseOutStructure *)&memory_priority_features;
 4844            device_extensions.push_back("VK_EXT_memory_priority");
 4845        }
 4846
 4847        VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
 4848        subgroup_size_control_features.pNext = nullptr;
 4849        subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
        subgroup_size_control_features.computeFullSubgroups = VK_FALSE;
        subgroup_size_control_features.subgroupSizeControl = VK_FALSE;
 4852
 4853        if (device->subgroup_size_control) {
 4854            last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
 4855            last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
 4856        }
 4857
 4858#if defined(VK_KHR_cooperative_matrix)
 4859        VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
 4860        coopmat_features.pNext = nullptr;
 4861        coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
 4862        coopmat_features.cooperativeMatrix = VK_FALSE;
 4863
 4864        if (device->coopmat_support) {
 4865            last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
 4866            last_struct = (VkBaseOutStructure *)&coopmat_features;
 4867        }
 4868#endif
 4869
 4870#if defined(VK_NV_cooperative_matrix2)
 4871        VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
 4872        coopmat2_features.pNext = nullptr;
 4873        coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
 4874        if (coopmat2_support) {
 4875            last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
 4876            last_struct = (VkBaseOutStructure *)&coopmat2_features;
 4877            device_extensions.push_back("VK_NV_cooperative_matrix2");
 4878        }
 4879#endif
 4880
 4881#if defined(VK_KHR_shader_bfloat16)
 4882        VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
 4883        bfloat16_features.pNext = nullptr;
 4884        bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
 4885        if (bfloat16_support) {
 4886            last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
 4887            last_struct = (VkBaseOutStructure *)&bfloat16_features;
 4888            device_extensions.push_back("VK_KHR_shader_bfloat16");
 4889        }
 4890#endif
 4891
 4892        VkPhysicalDeviceMaintenance4Features maint4_features {};
 4893        maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
 4894        if (maintenance4_support) {
 4895            last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
 4896            last_struct = (VkBaseOutStructure *)&maint4_features;
 4897            device_extensions.push_back("VK_KHR_maintenance4");
 4898        }
 4899
 4900        VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
 4901        shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
 4902        if (device->integer_dot_product) {
 4903            last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
 4904            last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
 4905            device_extensions.push_back("VK_KHR_shader_integer_dot_product");
 4906        }
 4907
 4908        VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
 4909        pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
 4910        if (pipeline_executable_properties_support) {
 4911            last_struct->pNext = (VkBaseOutStructure *)&pep_features;
 4912            last_struct = (VkBaseOutStructure *)&pep_features;
 4913            device_extensions.push_back("VK_KHR_pipeline_executable_properties");
 4914        }
 4915
 4916        if (device->external_memory_host) {
 4917            device_extensions.push_back("VK_EXT_external_memory_host");
 4918        }
 4919
 4920#if defined(VK_EXT_shader_64bit_indexing)
 4921        VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
 4922        shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
 4923        if (device->shader_64b_indexing) {
 4924            last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
 4925            last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
 4926            device_extensions.push_back("VK_EXT_shader_64bit_indexing");
 4927        }
 4928#endif
 4929
 4930        vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 4931
 4932        device->pipeline_executable_properties_support = pipeline_executable_properties_support;
 4933
 4934        device->fp16 = device->fp16 && vk12_features.shaderFloat16;
 4935
 4936#if defined(VK_KHR_shader_bfloat16)
 4937        device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
 4938#else
 4939        device->bf16 = false;
 4940#endif
 4941
 4942        device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
 4943
 4944        device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
 4945                            device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
 4946                            getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
 4947
 4948        device->shader_int64 = device_features2.features.shaderInt64;
 4949        device->buffer_device_address = vk12_features.bufferDeviceAddress;
 4950        device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
 4951
 4952        if (device->subgroup_size_control) {
 4953            device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
 4954            device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
 4955            device_extensions.push_back("VK_EXT_subgroup_size_control");
 4956        }
 4957
 4958        device->subgroup_size_control = device->subgroup_size_control &&
 4959                (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
 4960                subgroup_size_control_features.subgroupSizeControl;
 4961
 4962        device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
 4963
 4964#if defined(VK_KHR_cooperative_matrix)
 4965        device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
 4966
 4967        // coopmat1 fa shader currently assumes 32 invocations per subgroup
 4968        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
 4969                                      device->subgroup_size_control && device->subgroup_min_size <= 32 &&
 4970                                      device->subgroup_max_size >= 32;
 4971#endif
 4972
 4973        if (coopmat2_support) {
 4974#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 4975            if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
 4976                coopmat2_features.cooperativeMatrixFlexibleDimensions &&
 4977                coopmat2_features.cooperativeMatrixReductions &&
 4978                coopmat2_features.cooperativeMatrixConversions &&
 4979                coopmat2_features.cooperativeMatrixPerElementOperations &&
 4980                coopmat2_features.cooperativeMatrixTensorAddressing &&
 4981                coopmat2_features.cooperativeMatrixBlockLoads &&
 4982                vk12_features.bufferDeviceAddress) {
 4983
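                // Standard Vulkan two-call enumeration: query the count first,
                // then fill a correctly sized (and sType-initialized) array.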
 4984                std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
 4985                uint32_t count = 0;
 4986
 4987                PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
 4988                    _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
 4989                        (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
 4990                        vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
 4991
 4992                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
 4993
 4994                VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
 4995                empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
 4996                flexible_dimensions.resize(count, empty_prop);
 4997
 4998                _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
 4999
 5000                bool found_fp16_128 = false,
 5001                     found_fp16_256 = false,
 5002                     found_fp32_128 = false,
 5003                     found_fp32_256 = false;
                // Need fp16*fp16 with fp16 and fp32 accumulators: workgroup size 128
                // with granularity 32x16x16, and workgroup size 256 with 32x32x16.
 5006                for (auto &prop : flexible_dimensions) {
 5007                    if (prop.saturatingAccumulation == VK_FALSE &&
 5008                        prop.scope == VK_SCOPE_WORKGROUP_KHR &&
 5009                        prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
 5010                        prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
 5011
 5012                        if (prop.workgroupInvocations == 128 &&
 5013                            prop.MGranularity <= 32 &&
 5014                            prop.NGranularity <= 16 &&
 5015                            prop.KGranularity <= 16) {
 5016                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
 5017                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
 5018                                found_fp16_128 = true;
 5019                            }
 5020                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
 5021                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
 5022                                found_fp32_128 = true;
 5023                            }
 5024                        }
 5025                        if (prop.workgroupInvocations == 256 &&
 5026                            prop.MGranularity <= 32 &&
 5027                            prop.NGranularity <= 32 &&
 5028                            prop.KGranularity <= 16) {
 5029                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
 5030                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
 5031                                found_fp16_256 = true;
 5032                            }
 5033                            if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
 5034                                prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
 5035                                found_fp32_256 = true;
 5036                            }
 5037                        }
 5038                    }
 5039                }
 5040                if (found_fp16_128 && found_fp16_256 &&
 5041                    found_fp32_128 && found_fp32_256 &&
 5042                    coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
 5043                    device->coopmat2 = true;
 5044                }
 5045            }
 5046#endif
 5047        }
 5048
 5049        if (!vk11_features.storageBuffer16BitAccess) {
 5050            std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
 5051            throw std::runtime_error("Unsupported device");
 5052        }
 5053
 5054        device_extensions.push_back("VK_KHR_16bit_storage");
 5055
 5056#ifdef GGML_VULKAN_VALIDATE
 5057        device_extensions.push_back("VK_KHR_shader_non_semantic_info");
 5058#endif
 5059
 5060        if (device->fp16) {
 5061            device_extensions.push_back("VK_KHR_shader_float16_int8");
 5062        }
 5063
 5064#if defined(VK_KHR_cooperative_matrix)
 5065        if (device->coopmat_support) {
 5066            // Query supported shapes
 5067            std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
 5068
 5069            PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
 5070                (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
 5071
 5072            uint32_t cm_props_num;
 5073
 5074            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
 5075
 5076            cm_props.resize(cm_props_num);
 5077
 5078            for (auto& prop : cm_props) {
 5079                prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
 5080            }
 5081
 5082            pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
 5083
 5084            VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
 5085
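            // Lock onto the first supported fp16 subgroup-scope shape; further
            // accumulator types are only enabled if they report the same MxNxK
            // shape, since a single shape is stored per device.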
 5086            for (auto& prop : cm_props) {
 5087                VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
 5088
 5089                if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
 5090                    (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
 5091                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
 5092                ) {
 5093                    if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
 5094                        (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
 5095                        // coopmat sizes not set yet
 5096                        if (device->coopmat_m == 0) {
 5097                            device->coopmat_acc_f32_support = true;
 5098                            device->coopmat_m = prop.MSize;
 5099                            device->coopmat_n = prop.NSize;
 5100                            device->coopmat_k = prop.KSize;
 5101                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
 5102                            // Only enable if shape is identical
 5103                            device->coopmat_acc_f32_support = true;
 5104                        }
 5105                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
 5106                            device->coopmat_support_16x16x16_f32acc = true;
 5107                        }
 5108                    } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
 5109                               (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
 5110                        // coopmat sizes not set yet
 5111                        if (device->coopmat_m == 0) {
 5112                            device->coopmat_acc_f16_support = true;
 5113                            device->coopmat_m = prop.MSize;
 5114                            device->coopmat_n = prop.NSize;
 5115                            device->coopmat_k = prop.KSize;
 5116                        } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
 5117                            // Only enable if shape is identical
 5118                            device->coopmat_acc_f16_support = true;
 5119                        }
 5120                        if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
 5121                            device->coopmat_support_16x16x16_f16acc = true;
 5122                        }
 5123                    }
 5124                } else if ((vk::ComponentTypeKHR)prop.AType      == vk::ComponentTypeKHR::eSint8 &&
 5125                           (vk::ComponentTypeKHR)prop.BType      == vk::ComponentTypeKHR::eSint8 &&
 5126                           (vk::ComponentTypeKHR)prop.CType      == vk::ComponentTypeKHR::eSint32 &&
 5127                           (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
 5128                           (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
 5129                           device->coopmat_int_m == 0
 5130                ) {
 5131                    device->coopmat_int_support = true;
 5132                    device->coopmat_int_m = prop.MSize;
 5133                    device->coopmat_int_n = prop.NSize;
 5134                    device->coopmat_int_k = prop.KSize;
 5135                }
 5136#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 5137                if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
 5138                    prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
 5139                    prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
 5140                    prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
 5141                    (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
 5142                ) {
 5143                    // coopmat sizes not set yet
 5144                    if (device->coopmat_m == 0) {
 5145                        device->coopmat_bf16_support = true;
 5146                        device->coopmat_m = prop.MSize;
 5147                        device->coopmat_n = prop.NSize;
 5148                        device->coopmat_k = prop.KSize;
 5149                    } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
 5150                        // Only enable if shape is identical
 5151                        device->coopmat_bf16_support = true;
 5152                    }
 5153                }
 5154#endif
 5155            }
 5156
 5157            if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
 5158                // No suitable matmul mode found
 5159                GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
 5160                device->coopmat_support = false;
 5161            }
 5162            if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
 5163                device->coopmat_bf16_support = false;
 5164            }
 5165        }
 5166
 5167        if (device->coopmat_support) {
 5168            device_extensions.push_back("VK_KHR_cooperative_matrix");
 5169        }
 5170#if defined(VK_KHR_shader_bfloat16)
 5171        if (device->coopmat_bf16_support) {
 5172            device_extensions.push_back("VK_KHR_shader_bfloat16");
 5173        }
 5174#endif
 5175#endif
 5176        device->name = GGML_VK_NAME + std::to_string(idx);
 5177
 5178        device_create_info = {
 5179            vk::DeviceCreateFlags(),
 5180            device_queue_create_infos,
 5181            {},
 5182            device_extensions
 5183        };
 5184        device_create_info.setPNext(&device_features2);
 5185        device->device = device->physical_device.createDevice(device_create_info);
 5186
 5187        // Queues
 5188        ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
 5189
 5190        // Shaders
 5191        // Disable matmul tile sizes early if performance low or not supported
 5192        for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
 5193            switch (device->vendor_id) {
 5194#ifndef GGML_VULKAN_RUN_TESTS
 5195            case VK_VENDOR_ID_AMD:
 5196                device->mul_mat_l[i]    = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary;
 5197                device->mul_mat_m[i]    = true;
 5198                device->mul_mat_s[i]    = true;
 5199                device->mul_mat_id_l[i] = false;
 5200                device->mul_mat_id_m[i] = true;
 5201                device->mul_mat_id_s[i] = true;
 5202                break;
 5203            case VK_VENDOR_ID_INTEL:
 5204                if (!device->coopmat_support || device->architecture != INTEL_XE2) {
 5205                    device->mul_mat_l[i] = false;
 5206                    device->mul_mat_id_l[i] = false;
 5207                } else {
 5208                    device->mul_mat_l[i] = true;  // if coopmat & XE2+, allow large matmul warptile config for Intel
 5209                    device->mul_mat_id_l[i] = true;
 5210                }
 5211                device->mul_mat_m[i] = true;
 5212                device->mul_mat_s[i] = true;
 5213                device->mul_mat_id_m[i] = true;
 5214                device->mul_mat_id_s[i] = true;
 5215                break;
 5216            case VK_VENDOR_ID_APPLE:
 5217                device->mul_mat_l[i] = false;
 5218                device->mul_mat_m[i] = true;
 5219                device->mul_mat_s[i] = false;
 5220                device->mul_mat_id_l[i] = false;
 5221                device->mul_mat_id_m[i] = true;
 5222                device->mul_mat_id_s[i] = false;
 5223                break;
 5224#endif
 5225            default:
 5226                device->mul_mat_l[i] = true;
 5227                device->mul_mat_m[i] = true;
 5228                device->mul_mat_s[i] = true;
 5229                device->mul_mat_id_l[i] = true;
 5230                device->mul_mat_id_m[i] = true;
 5231                device->mul_mat_id_s[i] = true;
 5232                break;
 5233            }
 5234        }
 5235
 5236
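        // One shared descriptor set layout: MAX_PARAMETER_COUNT storage-buffer
        // bindings, all visible to the compute stage.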
 5237        std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
 5238        std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
 5239        for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
 5240            dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
 5241            dsl_binding_flags.push_back({});
 5242        }
 5243
 5244        vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
 5245
 5246        vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
 5247            {},
 5248            dsl_binding);
 5249        descriptor_set_layout_create_info.setPNext(&dslbfci);
 5250        device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
 5251
 5252        ggml_vk_load_shaders(device);
 5253
 5254        if (!device->single_queue) {
 5255            const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
 5256            ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
 5257        } else {
 5258            // TODO: Use pointer or reference to avoid copy
 5259            device->transfer_queue.copyFrom(device->compute_queue);
 5260            device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
 5261        }
 5262
 5263        device->buffer_type = {
 5264            /* .iface    = */ ggml_backend_vk_buffer_type_interface,
 5265            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
 5266            /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
 5267        };
 5268
 5269        device->fence = device->device.createFence({});
 5270
 5271        device->idx = idx;
 5272
 5273        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
 5274
 5275        device->add_rms_fusion = !device->disable_fusion &&
 5276                                 device->subgroup_arithmetic &&
 5277                                 device->vendor_id != VK_VENDOR_ID_INTEL;
 5278        device->partials_binding_alignment =
 5279            std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
 5280
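        // mmvq_mode: 0 = default heuristic, -1 = quantized mat-vec disabled via
        // GGML_VK_DISABLE_MMVQ, 1 = forced via GGML_VK_FORCE_MMVQ.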
 5281        device->mmvq_mode = 0;
 5282        if (getenv("GGML_VK_DISABLE_MMVQ")) {
 5283            device->mmvq_mode = -1;
 5284        } else if (getenv("GGML_VK_FORCE_MMVQ")) {
 5285            device->mmvq_mode = 1;
 5286        }
 5287
 5288        return device;
 5289    }
 5290
 5291    return vk_instance.devices[idx];
 5292}
 5293
 5294static void ggml_vk_print_gpu_info(size_t idx) {
 5295    GGML_ASSERT(idx < vk_instance.device_indices.size());
 5296    size_t dev_num = vk_instance.device_indices[idx];
 5297    VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
 5298    GGML_ASSERT(vk_instance_initialized);
 5299
 5300    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
 5301
 5302    if (dev_num >= devices.size()) {
 5303        std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
 5304        throw std::runtime_error("Device not found");
 5305    }
 5306
 5307    vk::PhysicalDevice physical_device = devices[dev_num];
 5308    std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
 5309
 5310    bool fp16_storage = false;
 5311    bool fp16_compute = false;
 5312    bool coopmat_support = false;
 5313    bool coopmat2_support = false;
 5314    bool integer_dot_product = false;
 5315    bool bfloat16_support = false;
 5316
    for (const auto & properties : ext_props) {
 5318        if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
 5319            fp16_storage = true;
 5320        } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
 5321            fp16_compute = true;
 5322#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 5323       } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
 5324                   !getenv("GGML_VK_DISABLE_COOPMAT")) {
 5325            coopmat_support = true;
 5326#endif
 5327#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
 5328        } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
 5329                   !getenv("GGML_VK_DISABLE_COOPMAT2")) {
 5330            coopmat2_support = true;
 5331#endif
 5332#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
 5333        } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
 5334                    !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
 5335            integer_dot_product = true;
 5336#endif
 5337#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
 5338        } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
 5339                    !getenv("GGML_VK_DISABLE_BFLOAT16")) {
 5340            bfloat16_support = true;
 5341#endif
 5342        }
 5343    }
 5344
 5345    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
 5346
 5347    const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
 5348    bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 5349
 5350    bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 5351
 5352    vk::PhysicalDeviceProperties2 props2;
 5353    vk::PhysicalDeviceMaintenance3Properties props3;
 5354    vk::PhysicalDeviceSubgroupProperties subgroup_props;
 5355    vk::PhysicalDeviceDriverProperties driver_props;
 5356    vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
 5357    props2.pNext = &props3;
 5358    props3.pNext = &subgroup_props;
 5359    subgroup_props.pNext = &driver_props;
 5360
 5361    // Pointer to the last chain element
 5362    VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
 5363
 5364    if (integer_dot_product) {
 5365        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
 5366        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
 5367    }
 5368
 5369    physical_device.getProperties2(&props2);
 5370
 5371    VkPhysicalDeviceFeatures2 device_features2;
 5372    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
 5373    device_features2.pNext = nullptr;
 5374
 5375    VkPhysicalDeviceVulkan11Features vk11_features;
 5376    vk11_features.pNext = nullptr;
 5377    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
 5378    device_features2.pNext = &vk11_features;
 5379
 5380    VkPhysicalDeviceVulkan12Features vk12_features;
 5381    vk12_features.pNext = nullptr;
 5382    vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
 5383    vk11_features.pNext = &vk12_features;
 5384
 5385    // Pointer to the last chain element
 5386    last_struct = (VkBaseOutStructure *)&vk12_features;
 5387
 5388#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 5389    VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
 5390    coopmat_features.pNext = nullptr;
 5391    coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
 5392    coopmat_features.cooperativeMatrix = VK_FALSE;
 5393
 5394    if (coopmat_support) {
 5395        last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
 5396        last_struct = (VkBaseOutStructure *)&coopmat_features;
 5397    }
 5398#endif
 5399
 5400    VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
 5401    shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
 5402    if (integer_dot_product) {
 5403        last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
 5404        last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
 5405    }
 5406
 5407#if defined(VK_KHR_shader_bfloat16)
 5408    VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
 5409    bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
 5410    if (bfloat16_support) {
 5411        last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
 5412        last_struct = (VkBaseOutStructure *)&bfloat16_features;
 5413    }
 5414#endif
 5415
 5416    vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 5417
 5418    fp16 = fp16 && vk12_features.shaderFloat16;
 5419
 5420#if defined(VK_KHR_shader_bfloat16)
 5421    bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
 5422#else
 5423    bool bf16 = false;
 5424#endif
 5425
 5426    uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
 5427    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
 5428    const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 5429
 5430    integer_dot_product = integer_dot_product
 5431                       && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
 5432                       && shader_integer_dot_product_features.shaderIntegerDotProduct;
 5433
 5434    coopmat_support = coopmat_support
 5435#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 5436                   && coopmat_features.cooperativeMatrix
 5437#endif
 5438                   && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
 5439
 5440    std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 5441
 5442    std::string device_name = props2.properties.deviceName.data();
 5443    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
 5444              idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
 5445              props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
 5446
 5447    if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
 5448        GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
 5449    }
 5450}
 5451
 5452static bool ggml_vk_instance_layer_settings_available();
 5453static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
 5454static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
 5455static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
 5456
 5457static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
 5458DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
 5459    return ggml_vk_default_dispatcher_instance;
 5460}
 5461
 5462static void ggml_vk_instance_init() {
 5463    if (vk_instance_initialized) {
 5464        return;
 5465    }
 5466    VK_LOG_DEBUG("ggml_vk_instance_init()");
 5467
 5468    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
 5469    ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
 5470
 5471    uint32_t api_version = vk::enumerateInstanceVersion();
 5472
 5473    if (api_version < VK_API_VERSION_1_2) {
 5474        std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
 5475        throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required");
 5476    }
 5477
 5478    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 5479
 5480    const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
 5481    const bool layer_settings = ggml_vk_instance_layer_settings_available();
 5482#ifdef __APPLE__
 5483    const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
 5484#endif
 5485    const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
 5486    std::vector<const char*> layers;
 5487
 5488    if (layer_settings) {
 5489        layers.push_back("VK_LAYER_KHRONOS_validation");
 5490    }
 5491    std::vector<const char*> extensions;
 5492    if (layer_settings) {
 5493        extensions.push_back("VK_EXT_layer_settings");
 5494    }
 5495#ifdef __APPLE__
 5496    if (portability_enumeration_ext) {
 5497        extensions.push_back("VK_KHR_portability_enumeration");
 5498    }
 5499#endif
 5500    if (debug_utils_ext) {
 5501        extensions.push_back("VK_EXT_debug_utils");
 5502    }
 5503    VkBool32 enable_best_practice = layer_settings;
 5504    std::vector<vk::LayerSettingEXT> settings = {
 5505        {
 5506            "VK_LAYER_KHRONOS_validation",
 5507            "validate_best_practices",
 5508            vk::LayerSettingTypeEXT::eBool32,
 5509            1,
 5510            &enable_best_practice
 5511        },
 5512    };
 5513    vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
 5514    vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
 5515#ifdef __APPLE__
 5516    if (portability_enumeration_ext) {
 5517        instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
 5518    }
 5519#endif
 5520
 5521    vk_instance.instance = vk::createInstance(instance_create_info);
 5522    vk_instance_initialized = true;
 5523
 5524    if (debug_utils_ext) {
 5525        vk_instance.debug_utils_support              = true;
 5526        vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
 5527        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
 5528        vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
 5529        vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
 5530        vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT =   (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
 5531        vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
 5532    }
 5533
 5534    vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
 5535    vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
 5536    vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
 5537    vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
 5538    const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
 5539
 5540    if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
 5541        vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);
 5542    }
 5543
 5544    // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
 5545    VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
 5546
 5547    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
 5548
 5549    // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
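    // e.g. GGML_VK_VISIBLE_DEVICES=0,2 exposes only the first and third physical devices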
 5550    char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
 5551    if (devices_env != nullptr) {
 5552        size_t num_available_devices = devices.size();
 5553
        std::string visible_devices(devices_env);
        std::replace(visible_devices.begin(), visible_devices.end(), ',', ' ');

        std::stringstream ss(visible_devices);
        size_t tmp;
        while (ss >> tmp) {
            if (tmp >= num_available_devices) {
 5561                std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
 5562                throw std::runtime_error("Invalid Vulkan device index");
 5563            }
 5564            vk_instance.device_indices.push_back(tmp);
 5565        }
 5566    } else {
        // If no Vulkan devices are found, return early
 5568        if (devices.empty()) {
 5569            GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
 5570            return;
 5571        }
 5572
 5573        // Default to using all dedicated GPUs
 5574        for (size_t i = 0; i < devices.size(); i++) {
 5575            vk::PhysicalDeviceProperties2 new_props;
 5576            vk::PhysicalDeviceDriverProperties new_driver;
 5577            vk::PhysicalDeviceIDProperties new_id;
 5578            new_props.pNext = &new_driver;
 5579            new_driver.pNext = &new_id;
 5580            devices[i].getProperties2(&new_props);
 5581
 5582            if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
 5583                // Check if there are two physical devices corresponding to the same GPU
 5584                // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
 5585                // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
                // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
                // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip deduplication when
                // both the old and the new driver are MoltenVK.
 5589                auto old_device = std::find_if(
 5590                    vk_instance.device_indices.begin(),
 5591                    vk_instance.device_indices.end(),
 5592                    [&devices, &new_id, &new_driver](const size_t k){
 5593                        vk::PhysicalDeviceProperties2 old_props;
 5594                        vk::PhysicalDeviceDriverProperties old_driver;
 5595                        vk::PhysicalDeviceIDProperties old_id;
 5596                        old_props.pNext = &old_driver;
 5597                        old_driver.pNext = &old_id;
 5598                        devices[k].getProperties2(&old_props);
 5599
 5600                        bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
 5601                        same_uuid = same_uuid || (
 5602                            old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
 5603                            std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
 5604                        );
 5605                        bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
 5606
 5607                        return same_uuid && !both_molten_vk;
 5608                    }
 5609                );
 5610                if (old_device == vk_instance.device_indices.end()) {
 5611                    vk_instance.device_indices.push_back(i);
 5612                } else {
                    // There can be two physical devices corresponding to the same GPU if two different drivers are installed.
                    // This can cause errors when splitting layers across the devices, so keep only one.
 5615                    VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
 5616
 5617                    vk::PhysicalDeviceProperties2 old_props;
 5618                    vk::PhysicalDeviceDriverProperties old_driver;
 5619                    old_props.pNext = &old_driver;
 5620                    devices[*old_device].getProperties2(&old_props);
 5621
 5622                    std::map<vk::DriverId, int> driver_priorities {};
 5623                    int old_priority = std::numeric_limits<int>::max();
 5624                    int new_priority = std::numeric_limits<int>::max();
 5625
                    // See https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver IDs
 5627                    // Smaller number -> higher priority
 5628                    switch (old_props.properties.vendorID) {
 5629                        case VK_VENDOR_ID_AMD:
 5630                            driver_priorities[vk::DriverId::eMesaRadv] = 1;
 5631                            driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
 5632                            driver_priorities[vk::DriverId::eAmdProprietary] = 3;
 5633                            break;
 5634                        case VK_VENDOR_ID_INTEL:
 5635                            driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
 5636                            driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
 5637                            break;
 5638                        case VK_VENDOR_ID_NVIDIA:
 5639                            driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
 5640#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
 5641                            driver_priorities[vk::DriverId::eMesaNvk] = 2;
 5642#endif
 5643                            break;
 5644                    }
 5645                    driver_priorities[vk::DriverId::eMesaDozen] = 100;
 5646
 5647                    if (driver_priorities.count(old_driver.driverID)) {
 5648                        old_priority = driver_priorities[old_driver.driverID];
 5649                    }
 5650                    if (driver_priorities.count(new_driver.driverID)) {
 5651                        new_priority = driver_priorities[new_driver.driverID];
 5652                    }
 5653
 5654                    if (new_priority < old_priority) {
 5655                        auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
 5656                        vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
 5657                        vk_instance.device_indices.push_back(i);
 5658
 5659                        VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
 5660                    }
 5661                    else {
 5662                        VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl);
 5663                    }
 5664                }
 5665            }
 5666        }
 5667
 5668        // If no GPUs found, fall back to the first non-CPU device.
 5669        // If only CPU devices are available, return without devices.
 5670        if (vk_instance.device_indices.empty()) {
 5671            for (size_t i = 0; i < devices.size(); i++) {
 5672                if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {
 5673                    vk_instance.device_indices.push_back(i);
 5674                    break;
 5675                }
 5676            }
 5677        }
 5678
 5679        if (vk_instance.device_indices.empty()) {
 5680            GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
 5681            return;
 5682        }
 5683    }
 5684    GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
 5685
 5686    for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
 5687        vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
 5688        std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
 5689
 5690        bool membudget_supported = false;
 5691        for (const auto & ext : extensionprops) {
 5692            if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
 5693                membudget_supported = true;
 5694                break;
 5695            }
 5696        }
 5697
 5698        vk_instance.device_supports_membudget.push_back(membudget_supported);
 5699
 5700        ggml_vk_print_gpu_info(i);
 5701    }
 5702}
 5703
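// Initialize a backend context for device idx: resolve the shared device, create
// the fences and the compute command pool, and read debugging-related env vars.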
 5704static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 5705    VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
 5706    ggml_vk_instance_init();
 5707    GGML_ASSERT(idx < vk_instance.device_indices.size());
 5708
 5709    ctx->name = GGML_VK_NAME + std::to_string(idx);
 5710
 5711    ctx->device = ggml_vk_get_device(idx);
 5712
 5713    ctx->semaphore_idx = 0;
 5714    ctx->event_idx = 0;
 5715
 5716    ctx->prealloc_size_x = 0;
 5717    ctx->prealloc_size_y = 0;
 5718    ctx->prealloc_size_split_k = 0;
 5719    // Fixed size of 1KB, for deterministic behavior
 5720    ctx->prealloc_size_add_rms_partials = 1024;
 5721
 5722    ctx->fence = ctx->device->device.createFence({});
 5723    ctx->almost_ready_fence = ctx->device->device.createFence({});
 5724
 5725    ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
 5726
 5727    if (vk_perf_logger_enabled) {
 5728        ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
 5729    }
 5730
 5731#ifdef GGML_VULKAN_CHECK_RESULTS
 5732    const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
 5733    vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
 5734    const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
 5735    vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
 5736#endif
 5737}
 5738
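// Returns the dequantization-to-fp16 pipeline for the given type, or nullptr if
// the type is not supported.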
 5739static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
 5740    VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
 5741    switch (type) {
 5742        case GGML_TYPE_F32:
 5743        case GGML_TYPE_Q4_0:
 5744        case GGML_TYPE_Q4_1:
 5745        case GGML_TYPE_Q5_0:
 5746        case GGML_TYPE_Q5_1:
 5747        case GGML_TYPE_Q8_0:
 5748        case GGML_TYPE_Q2_K:
 5749        case GGML_TYPE_Q3_K:
 5750        case GGML_TYPE_Q4_K:
 5751        case GGML_TYPE_Q5_K:
 5752        case GGML_TYPE_Q6_K:
 5753        case GGML_TYPE_IQ1_S:
 5754        case GGML_TYPE_IQ1_M:
 5755        case GGML_TYPE_IQ2_XXS:
 5756        case GGML_TYPE_IQ2_XS:
 5757        case GGML_TYPE_IQ2_S:
 5758        case GGML_TYPE_IQ3_XXS:
 5759        case GGML_TYPE_IQ3_S:
 5760        case GGML_TYPE_IQ4_XS:
 5761        case GGML_TYPE_IQ4_NL:
 5762        case GGML_TYPE_MXFP4:
 5763            break;
 5764        default:
 5765            return nullptr;
 5766    }
 5767
 5768    return ctx->device->pipeline_dequant[type];
 5769}
 5770
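// Select a matrix-matrix multiplication pipeline: f32/bf16 paths first, then f16
// with f16 or f32 accumulators depending on the requested precision and device
// support, then quantized paths (Q8_1 MMQ, otherwise dequant-based pipelines).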
 5771static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
 5772    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
 5773    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
 5774        return ctx->device->pipeline_matmul_f32;
 5775    }
 5776    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
 5777        return ctx->device->pipeline_matmul_f32_f16;
 5778    }
 5779    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
 5780        return ctx->device->pipeline_matmul_bf16;
 5781    }
 5782    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
 5783        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
 5784            return ctx->device->pipeline_matmul_f16_f32.f16acc;
 5785        }
 5786        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
 5787            return ctx->device->pipeline_matmul_f16.f16acc;
 5788        }
 5789    } else {
 5790        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
 5791            return ctx->device->pipeline_matmul_f16_f32.f32acc;
 5792        }
 5793        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
 5794            return ctx->device->pipeline_matmul_f16.f32acc;
 5795        }
 5796    }
 5797
 5798    // MMQ
 5799    if (src1_type == GGML_TYPE_Q8_1) {
 5800        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
 5801
 5802        if (pipelines->is_empty()) {
 5803            return nullptr;
 5804        }
 5805
 5806        return pipelines;
 5807    }
 5808
 5809    if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
 5810        return nullptr;
 5811    }
 5812
 5813    switch (src0_type) {
 5814        case GGML_TYPE_Q4_0:
 5815        case GGML_TYPE_Q4_1:
 5816        case GGML_TYPE_Q5_0:
 5817        case GGML_TYPE_Q5_1:
 5818        case GGML_TYPE_Q8_0:
 5819        case GGML_TYPE_Q2_K:
 5820        case GGML_TYPE_Q3_K:
 5821        case GGML_TYPE_Q4_K:
 5822        case GGML_TYPE_Q5_K:
 5823        case GGML_TYPE_Q6_K:
 5824        case GGML_TYPE_IQ1_S:
 5825        case GGML_TYPE_IQ1_M:
 5826        case GGML_TYPE_IQ2_XXS:
 5827        case GGML_TYPE_IQ2_XS:
 5828        case GGML_TYPE_IQ2_S:
 5829        case GGML_TYPE_IQ3_XXS:
 5830        case GGML_TYPE_IQ3_S:
 5831        case GGML_TYPE_IQ4_XS:
 5832        case GGML_TYPE_IQ4_NL:
 5833        case GGML_TYPE_MXFP4:
 5834            break;
 5835        default:
 5836            return nullptr;
 5837    }
 5838
 5839    if (ctx->device->coopmat2) {
 5840        assert(src1_type == GGML_TYPE_F16);
 5841        return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
 5842    }
 5843    if (ctx->device->coopmat_support) {
 5844        return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 5845    }
 5846    return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
 5847}

static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
    GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);

    if (b_type == GGML_TYPE_Q8_1) {
        switch (a_type) {
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q4_1:
            case GGML_TYPE_Q5_0:
            case GGML_TYPE_Q5_1:
            case GGML_TYPE_Q8_0:
            case GGML_TYPE_MXFP4:
            case GGML_TYPE_Q2_K:
            case GGML_TYPE_Q3_K:
            case GGML_TYPE_Q4_K:
            case GGML_TYPE_Q5_K:
            case GGML_TYPE_Q6_K:
            case GGML_TYPE_IQ1_S:
            case GGML_TYPE_IQ1_M:
                break;
            default:
                return nullptr;
        }
    }

    switch (a_type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_MXFP4:
            break;
        default:
            return nullptr;
    }

    // heuristic to choose workgroup size
    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
        // Prefer larger workgroups when M is small, to spread the work out more
        // and keep more SMs busy.
        // q6_k seems to prefer small workgroup size even for "medium" values of M.
        if (a_type == GGML_TYPE_Q6_K) {
            if (m < 4096 && k >= 1024) {
                dmmv_wg = DMMV_WG_SIZE_LARGE;
            }
        } else {
            if (m <= 8192 && k >= 1024) {
                dmmv_wg = DMMV_WG_SIZE_LARGE;
            }
        }
    }

    if (b_type == GGML_TYPE_Q8_1) {
        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
        }
        return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
    }

    return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
}

static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
    VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
    if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
        return ctx->device->pipeline_matmul_id_f32;
    }
    if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
        return ctx->device->pipeline_matmul_id_bf16;
    }
    if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
            return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
        }
        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
            return ctx->device->pipeline_matmul_id_f16.f16acc;
        }
    } else {
        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
            return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
        }
        if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
            return ctx->device->pipeline_matmul_id_f16.f32acc;
        }
    }

    // MMQ
    if (src1_type == GGML_TYPE_Q8_1) {
        vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;

        if (pipelines->is_empty()) {
            return nullptr;
        }

        return pipelines;
    }

    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));

    switch (src0_type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_MXFP4:
            break;
        default:
            return nullptr;
    }

    vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
    // XXX TODO 'prec' is not actually allowed in mul_mat_id.
    bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
    bool support_fp16acc = !mmp.f16acc->is_empty();
    bool support_fp32acc = !mmp.f32acc->is_empty();

    if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
        return mmp.f16acc;
    } else {
        GGML_ASSERT(support_fp32acc);
        return mmp.f32acc;
    }
}

static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
    VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
    GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);

    if (b_type == GGML_TYPE_Q8_1) {
        switch (a_type) {
            case GGML_TYPE_Q4_0:
            case GGML_TYPE_Q4_1:
            case GGML_TYPE_Q5_0:
            case GGML_TYPE_Q5_1:
            case GGML_TYPE_Q8_0:
            case GGML_TYPE_MXFP4:
            case GGML_TYPE_Q2_K:
            case GGML_TYPE_Q3_K:
            case GGML_TYPE_Q4_K:
            case GGML_TYPE_Q5_K:
            case GGML_TYPE_Q6_K:
            case GGML_TYPE_IQ1_S:
            case GGML_TYPE_IQ1_M:
                break;
            default:
                return nullptr;
        }
    }

    switch (a_type) {
        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
        case GGML_TYPE_BF16:
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_Q2_K:
        case GGML_TYPE_Q3_K:
        case GGML_TYPE_Q4_K:
        case GGML_TYPE_Q5_K:
        case GGML_TYPE_Q6_K:
        case GGML_TYPE_IQ1_S:
        case GGML_TYPE_IQ1_M:
        case GGML_TYPE_IQ2_XXS:
        case GGML_TYPE_IQ2_XS:
        case GGML_TYPE_IQ2_S:
        case GGML_TYPE_IQ3_XXS:
        case GGML_TYPE_IQ3_S:
        case GGML_TYPE_IQ4_XS:
        case GGML_TYPE_IQ4_NL:
        case GGML_TYPE_MXFP4:
            break;
        default:
            return nullptr;
    }

    // heuristic to choose workgroup size
    uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
    if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
        // Prefer larger workgroups when M is small, to spread the work out more
        // and keep more SMs busy.
        // q6_k seems to prefer small workgroup size even for "medium" values of M.
        if (a_type == GGML_TYPE_Q6_K) {
            if (m < 4096 && k >= 1024) {
                dmmv_wg = DMMV_WG_SIZE_LARGE;
            }
        } else {
            if (m <= 8192 && k >= 1024) {
                dmmv_wg = DMMV_WG_SIZE_LARGE;
            }
        }
    }

    if (b_type == GGML_TYPE_Q8_1) {
        if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
            dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
        }
        return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
    }

    return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
}

static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
    VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
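    // Request cached host memory first, then fall back to plain
    // host-visible/coherent memory; the two property-flag sets below are
    // assumed to be tried in order by ggml_vk_create_buffer.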
    vk_buffer buf = ggml_vk_create_buffer(device, size,
        {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
         vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});

    if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
        fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
            size/1024.0/1024.0);
        device->device.freeMemory(buf->device_memory);
        device->device.destroyBuffer(buf->buffer);
        return nullptr;
    }

    std::lock_guard<std::recursive_mutex> guard(device->mutex);
    device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));

    return buf->ptr;
}

static void ggml_vk_host_free(vk_device& device, void* ptr) {
    if (ptr == nullptr) {
        return;
    }
    VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
    std::lock_guard<std::recursive_mutex> guard(device->mutex);

    vk_buffer buf;
    size_t index;
    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
        if (ptr >= addr && ptr < endr) {
            buf = std::get<2>(device->pinned_memory[i]);
            index = i;
            break;
        }
    }
    if (buf == nullptr) {
        fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
        return;
    }

    ggml_vk_destroy_buffer(buf);

    device->pinned_memory.erase(device->pinned_memory.begin() + index);
}

static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
    std::lock_guard<std::recursive_mutex> guard(device->mutex);
    buf = nullptr;
    buf_offset = 0;
    for (size_t i = 0; i < device->pinned_memory.size(); i++) {
        const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
        const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
        if (ptr >= addr && ptr < endr) {
            buf = std::get<2>(device->pinned_memory[i]);
            buf_offset = ((const uint8_t *)ptr) - addr;
            break;
        }
    }
}

static vk_subbuffer ggml_vk_tensor_subbuffer(
    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {

    vk_buffer buffer = nullptr;
    size_t offset = 0;
    if (ctx->device->uma) {
        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
    }
    if (!buffer) {
        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
        buffer = buf_ctx->dev_buffer;
        offset = vk_tensor_offset(tensor) + tensor->view_offs;
    }
    GGML_ASSERT(buffer != nullptr);

    size_t size = ggml_nbytes(tensor);

    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
    // The shader must support misaligned offsets when indexing into the buffer
    GGML_ASSERT(allow_misalign || misalign_bytes == 0);
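    // Worked example (hypothetical values): with a 64-byte
    // minStorageBufferOffsetAlignment, offset 0x1234 gives
    // misalign_bytes == 0x34; the subbuffer is rebased to 0x1200 and its size
    // grown by 0x34, so the shader can still index from the original offset.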
    offset &= ~misalign_bytes;
    size += misalign_bytes;

    return vk_subbuffer{buffer, offset, size};
}

static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
    vk_submission s;
    s.buffer = ggml_vk_create_cmd_buffer(device, p);
    if (one_time) {
        s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
    } else {
        s.buffer.begin({ vk::CommandBufferUsageFlags{} });
    }

    return s;
}

template <typename T> size_t push_constant_size(const T &t) {
    static_assert(std::is_class<T>::value, "T must be a struct/class");
    GGML_UNUSED(t);
    return sizeof(T);
}
template <typename T> size_t push_constant_size(const std::vector<T> &t) {
    GGML_UNUSED(t);
    return sizeof(T) * t.size();
}
template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
    GGML_UNUSED(t);
    return sizeof(T) * N;
}

template <typename T> const T *push_constant_data(const T &t) {
    static_assert(std::is_class<T>::value, "T must be a struct/class");
    return &t;
}
template <typename T> const T *push_constant_data(const std::vector<T> &t) {
    return t.data();
}
template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
    return t.data();
}
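// These overloads let ggml_vk_dispatch_pipeline accept a plain struct, a
// std::vector, or a std::array as push constants while still validating the
// byte size against pipeline->push_constant_size below.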

template <typename T>
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
    const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
    const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
    const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
    VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
    for (auto& buffer : descriptor_buffer_infos) {
        std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
    }
    std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
    GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
                wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
                wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
    GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
    GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
    GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));

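    // One pre-allocated descriptor set is consumed per dispatch; a single
    // write with descriptorCount == parameter_count updates consecutive
    // storage-buffer bindings starting at binding 0.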
    vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
    vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
    ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});

    subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
    subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
    subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
                                pipeline->layout,
                                0,
                                { descriptor_set },
                                {});
    subctx->s->buffer.dispatch(wg0, wg1, wg2);
}

static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
    s.buffer.end();

    s.wait_semaphores = std::move(wait_semaphores);
    s.signal_semaphores = std::move(signal_semaphores);
}

static void ggml_vk_ctx_end(vk_context& ctx) {
    VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
    if (ctx->s == nullptr) {
        return;
    }

    ctx->s->buffer.end();
    ctx->s = nullptr;
}

static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
    VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
    if (subctx->s != nullptr) {
        ggml_vk_ctx_end(subctx);
    }

    subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
    subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
}

static size_t ggml_vk_align_size(size_t width, size_t align) {
    VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
    return CEIL_DIV(width, align) * align;
}

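// deferred_memcpy/deferred_memset either run immediately (when no list is
// given) or record the operation for later replay: in_memcpys are replayed
// right before submission and out_memcpys after the fence signals (see
// ggml_vk_buffer_write_2d and ggml_vk_buffer_read below).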
static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
    if (memcpys == nullptr) {
        memcpy(dst, src, size);
    } else {
        memcpys->emplace_back(dst, src, size);
    }
}

static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
    if (memsets == nullptr) {
        memset(dst, val, size);
    } else {
        memsets->emplace_back(dst, val, size);
    }
}

static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
    if (device->sync_staging == nullptr || device->sync_staging->size < size) {
        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
        ggml_vk_destroy_buffer(device->sync_staging);
        device->sync_staging = ggml_vk_create_buffer_check(device, size,
            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
    }
}

static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
    if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
        VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
        ggml_vk_destroy_buffer(ctx->sync_staging);
        ctx->sync_staging = ggml_vk_create_buffer_check(ctx->device, size,
            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
            vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
    }
}

static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
    VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
    GGML_ASSERT(!ggml_is_contiguous(tensor));
    // Buffer is already mapped
    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
        std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
        GGML_ABORT("fatal error");
    }
    // Check if src is pinned memory
    vk_buffer buf = nullptr;
    size_t buf_offset = 0;
    ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);

    const uint64_t ne0 = tensor->ne[0];
    const uint64_t ne1 = tensor->ne[1];
    const uint64_t ne2 = tensor->ne[2];
    const uint64_t ne3 = tensor->ne[3];
    const uint64_t nb0 = tensor->nb[0];
    const uint64_t nb1 = tensor->nb[1];
    const uint64_t nb2 = tensor->nb[2];
    const uint64_t nb3 = tensor->nb[3];
    const ggml_type type = tensor->type;
    const uint64_t ts = ggml_type_size(type);
    const uint64_t bs = ggml_blck_size(type);

    const uint64_t dstnb0 = ts;
    const uint64_t dstnb1 = dstnb0*(ne0/bs);
    const uint64_t dstnb2 = dstnb1*ne1;
    const uint64_t dstnb3 = dstnb2*ne2;

    const uint64_t ne = ggml_nelements(tensor);

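    // If the source is pinned host memory, copy directly from it; otherwise
    // stage the data through sync_staging with deferred memcpys.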
    if (buf != nullptr) {
        // Memory is pinned, use as staging buffer
        std::vector<vk::BufferCopy> slices;

        for (uint64_t i3 = 0; i3 < ne3; i3++) {
            for (uint64_t i2 = 0; i2 < ne2; i2++) {
                // Find longest contiguous slice
                if (ne1*nb1 == dstnb2) {
                    slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
                } else {
                    for (uint64_t i1 = 0; i1 < ne1; i1++) {
                        if (ne0*nb0/bs == dstnb1) {
                            slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
                        } else {
                            const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
                            const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
                            for (uint64_t i0 = 0; i0 < ne0; i0++) {
                                slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
                            }
                        }
                    }
                }
            }
        }

        ggml_vk_sync_buffers(ctx, subctx);
        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
        return;
    }

    if (!sync_staging) {
        GGML_ABORT("Asynchronous write to non-pinned memory not supported");
    }

    // Staging buffer required
    vk_buffer& staging = ctx->device->sync_staging;
    const uint64_t copy_size = ts*ne/bs;
    ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
    VkBufferCopy buf_copy{ 0, offset, copy_size };

    ggml_vk_sync_buffers(ctx, subctx);
    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);

    for (uint64_t i3 = 0; i3 < ne3; i3++) {
        for (uint64_t i2 = 0; i2 < ne2; i2++) {
            // Find longest contiguous slice
            if (ne1*nb1 == dstnb2) {
                deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
            } else {
                for (uint64_t i1 = 0; i1 < ne1; i1++) {
                    if (ne0*nb0/bs == dstnb1) {
                        deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
                    } else {
                        const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
                        const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
                        for (uint64_t i0 = 0; i0 < ne0; i0++) {
                            deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
                        }
                    }
                }
            }
        }
    }
}

static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
    VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
    // Check if src is pinned memory
    vk_buffer buf = nullptr;
    size_t buf_offset = 0;
    ggml_vk_host_get(dst->device, src, buf, buf_offset);

    if (buf != nullptr) {
        // Memory is pinned, use as staging buffer
        std::vector<vk::BufferCopy> slices(1);
        if (width == spitch) {
            // Only do single write if stride is equal
            slices[0].srcOffset = buf_offset;
            slices[0].dstOffset = offset;
            slices[0].size = width * height;
        } else {
            slices.resize(height);
            for (size_t i = 0; i < height; i++) {
                slices[i].srcOffset = buf_offset + i * spitch;
                slices[i].dstOffset = offset + i * width;
                slices[i].size = width;
            }
        }

        ggml_vk_sync_buffers(nullptr, subctx);
        subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
        return true;
    }
    VK_LOG_DEBUG("STAGING");

    if (!sync_staging) {
        // copy was not handled; the caller needs to fall back
        return false;
    }

    // Staging buffer required
    const size_t copy_size = width*height;
    ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);

    vk_buffer& staging_buffer = dst->device->sync_staging;

    VkBufferCopy buf_copy = {
        0,
        offset,
        copy_size};

    ggml_vk_sync_buffers(nullptr, subctx);
    vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);

    if (width == spitch) {
        deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
    } else {
        for (size_t i = 0; i < height; i++) {
            deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
        }
    }
    return true;
}

static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
    VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
    return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
}

static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
    VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
    // Buffer is already mapped
    if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
        GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);

        for (size_t i = 0; i < height; i++) {
            memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
        }
    } else {
        std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);

        vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
        ggml_vk_ctx_begin(dst->device, subctx);
        bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
        GGML_ASSERT(ret);
        ggml_vk_ctx_end(subctx);

        for (auto& cpy : subctx->in_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }

        for (auto& mset : subctx->memsets) {
            memset(mset.dst, mset.val, mset.n);
        }

        ggml_vk_submit(subctx, dst->device->fence);
        VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
        dst->device->device.resetFences({ dst->device->fence });
        ggml_vk_queue_command_pools_cleanup(dst->device);
    }
}

static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
    VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
    ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
}

static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
    VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
    GGML_ASSERT(width > 0);
    GGML_ASSERT(height > 0);
    GGML_ASSERT(src != nullptr);

    // TODO: staging_offset is not used

    // Check if dst is pinned memory
    vk_buffer buf = nullptr;
    size_t buf_offset = 0;
    ggml_vk_host_get(src->device, dst, buf, buf_offset);

    std::vector<vk::BufferCopy> slices(1);
    if (width == spitch && width == dpitch) {
        // Only do single write if stride is equal
        slices[0].srcOffset = offset;
        slices[0].dstOffset = buf_offset;
        slices[0].size = width * height;
    } else {
        slices.resize(height);
        for (size_t i = 0; i < height; i++) {
            slices[i].srcOffset = offset + i * spitch;
            slices[i].dstOffset = buf_offset + i * dpitch;
            slices[i].size = width;
        }
    }

    if (buf != nullptr) {
        // Memory is pinned, use as staging buffer
        ggml_vk_sync_buffers(nullptr, subctx);
        subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);

        return true;
    }
    VK_LOG_DEBUG("STAGING");

    if (!sync_staging) {
        // copy was not handled; the caller needs to fall back
        return false;
    }

    // Fall back to staging buffer
    const size_t copy_size = dpitch * height;
    ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);

    vk_buffer& staging_buffer = src->device->sync_staging;

    ggml_vk_sync_buffers(nullptr, subctx);
    subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);

    deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
    return true;
}

static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
    return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
}

static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
    VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");

    // If the device is not a UMA device, the memory may still be host-accessible through ReBAR.
    // While writing through PCIe is sufficiently fast, reading data back over PCIe is slower
    // than going through the HW device-to-host copy path.
    if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
        GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);

        memcpy(dst, (uint8_t *) src->ptr + offset, size);
    } else {
        std::lock_guard<std::recursive_mutex> guard(src->device->mutex);

        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
        ggml_vk_ctx_begin(src->device, subctx);
        bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
        GGML_ASSERT(ret);
        ggml_vk_ctx_end(subctx);

        ggml_vk_submit(subctx, src->device->fence);
        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
        src->device->device.resetFences({ src->device->fence });
        ggml_vk_queue_command_pools_cleanup(src->device);

        for (auto& cpy : subctx->out_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }
    }
}

static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
    VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
    // Make sure both buffers are on the same device
    GGML_ASSERT(src->device == dst->device);

    VkBufferCopy bc{ src_offset, dst_offset, size };

    vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
}

static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
    if (src->device == dst->device) {
        std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
        VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
        // Copy within the device
        vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
        ggml_vk_ctx_begin(src->device, subctx);
        ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
        ggml_vk_ctx_end(subctx);
        ggml_vk_submit(subctx, src->device->fence);
        VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
        src->device->device.resetFences({ src->device->fence });
        ggml_vk_queue_command_pools_cleanup(src->device);
    } else {
        VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
        // Copy device to device
        ggml_vk_ensure_sync_staging_buffer(src->device, size);

        // Copy to src staging buffer
        ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
        // Copy to dst buffer
        ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
    }
}

static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
    VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");

    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
        dst->device->uma) {
        deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
        return;
    }

    // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
    ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
}

static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
    VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");

    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
        dst->device->uma) {
        memset((uint8_t*)dst->ptr + offset, c, size);
        return;
    }

    std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
    vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
    ggml_vk_ctx_begin(dst->device, subctx);
    subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
    ggml_vk_ctx_end(subctx);

    ggml_vk_submit(subctx, dst->device->fence);
    VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
    dst->device->device.resetFences({ dst->device->fence });
    ggml_vk_queue_command_pools_cleanup(dst->device);
}

static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
    VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");

    if (disable_split_k) {
        return 1;
    }

    uint32_t split_k = 1;
    if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
        // If k is 'large' and the SMs will fill less than halfway, use split_k.
        uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
        uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);

        if (k >= 2048) {
            if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
                split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
            } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
                split_k = 3;
            }
            // Cap the split at 8x. Unless k is huge, more splits add a lot of overhead.
            split_k = std::min(split_k, 8u);

            // ggml_vk_matmul will align the splits to be a multiple of 256.
            // If this rounded up size would cause the last split to be empty,
            // then reduce the split count.
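            // Hypothetical example: k = 4608, split_k = 8 gives
            // k_split = ROUNDUP_POW2(CEIL_DIV(4608, 8), 256) = 768, but
            // 768 * 7 >= 4608 leaves the last split empty; the loop below
            // settles on split_k = 6, where 768 * 5 < 4608.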
            while (true) {
                if (split_k == 1) {
                    break;
                }
                uint32_t k_split = CEIL_DIV(k, split_k);
                k_split = ROUNDUP_POW2(k_split, 256);
                if (k_split * (split_k - 1) < k) {
                    break;
                }
                split_k--;
            }
        }
    }

    return split_k;
}

static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");

    if (ctx->device->coopmat2) {
        const uint32_t shader_core_count = ctx->device->shader_core_count;
        const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
        const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);

        // Use large shader when the N dimension is greater than the medium shader's tile size
        uint32_t crossover_large = mmp->m->wg_denoms[1];

        // Prefer large over medium if either:
        // - medium or large tiles would overfill the GPU
        // - large tiles with split_k==3 fit in the GPU and medium tiles with split_k==2 do not
        //   (medium with split_k==2 is probably better if it fits: more workgroups running and less split_k overhead)
        bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
                            // split_k==3 with large tiles likely better than medium tiles with no split_k.
                            (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);

        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
            return aligned ? mmp->a_l : mmp->l;
        }
        // Use medium shader when the N dimension is greater than the small shader's tile size
        uint32_t crossover_medium = mmp->s->wg_denoms[1];
        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;
    }

    if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
        return aligned ? mmp->a_s : mmp->s;
    }
    if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
        return aligned ? mmp->a_m : mmp->m;
    }
    return aligned ? mmp->a_l : mmp->l;

    GGML_UNUSED(src1_type);
}

static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
}

static void ggml_vk_matmul(
        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
        uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
    if (split_k == 1) {
        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
        return;
    }

    if (ctx->prealloc_split_k_need_sync) {
        ggml_vk_sync_buffers(ctx, subctx);
    }

    GGML_ASSERT(batch_stride_d == m * n);

    // Round the split size up to a multiple of 256 (k-quant alignment)
    uint32_t k_split = CEIL_DIV(k, split_k);
    k_split = ROUNDUP_POW2(k_split, 256);

    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
    // Make sure enough workgroups get assigned for split k to work
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
    ggml_vk_sync_buffers(ctx, subctx);
    const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
    ctx->prealloc_split_k_need_sync = true;
}

static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

    if (ctx->device->coopmat2) {
        // Use large shader when the N dimension is greater than the medium shader's tile size
        uint32_t crossover_large = mmp->m->wg_denoms[1];
        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
            return aligned ? mmp->a_l : mmp->l;
        }
        // Use medium shader when the N dimension is greater than the small shader's tile size
        uint32_t crossover_medium = mmp->s->wg_denoms[1];
        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
            return aligned ? mmp->a_m : mmp->m;
        }
        return aligned ? mmp->a_s : mmp->s;
    }

    if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
        return aligned ? mmp->a_s : mmp->s;
    }
    if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
        return aligned ? mmp->a_m : mmp->m;
    }
    return aligned ? mmp->a_l : mmp->l;
}

static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
    return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
}

static void ggml_vk_matmul_id(
        ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
        vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
        uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
        uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
        uint32_t padded_n) {
    VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
        "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
        "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
        "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
    const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
                                              nei0, nei1, nbi1, ne11, padded_n };
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
}

static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
    return
        tensor->nb[0] == ggml_type_size(tensor->type) &&
        tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
        (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
}

static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {

    // Choose "contiguous copy" shader if src/dst are contiguous
    bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));

    // Use optimized "transpose" shader if src dim1 is the innermost dimension.
    bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);
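    // (A transposed view has its dim-1 stride equal to the element size,
    // e.g. nb[1] == 4 for an F32 tensor produced by ggml_transpose.)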

    if (transpose && src->type == to) {
        if (ggml_type_size(to) == 4) {
            return ctx->device->pipeline_cpy_transpose_32;
        } else if (ggml_type_size(to) == 2) {
            return ctx->device->pipeline_cpy_transpose_16;
        }
    }

    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f32_f32;
        } else {
            return ctx->device->pipeline_cpy_f32_f32;
        }
    }
    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f32_f16;
        } else {
            return ctx->device->pipeline_cpy_f32_f16;
        }
    }
    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f16_f16;
        } else {
            return ctx->device->pipeline_cpy_f16_f16;
        }
    }
    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f16_f32;
        } else {
            return ctx->device->pipeline_cpy_f16_f32;
        }
    }
    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f32_bf16;
        } else {
            return ctx->device->pipeline_cpy_f32_bf16;
        }
    }
    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_f32_i32;
        } else {
            return ctx->device->pipeline_cpy_f32_i32;
        }
    }
    if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
        if (contig) {
            return ctx->device->pipeline_contig_cpy_i32_f32;
        } else {
            return ctx->device->pipeline_cpy_i32_f32;
        }
    }
    if (src->type == GGML_TYPE_F32) {
        switch (to) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_IQ4_NL:
            return ctx->device->pipeline_cpy_f32_quant[to];
        default:
            break;
        }
    }

    if (to == GGML_TYPE_F32) {
        switch (src->type) {
        case GGML_TYPE_Q4_0:
        case GGML_TYPE_Q4_1:
        case GGML_TYPE_Q5_0:
        case GGML_TYPE_Q5_1:
        case GGML_TYPE_Q8_0:
        case GGML_TYPE_IQ4_NL:
            return ctx->device->pipeline_cpy_quant_f32[src->type];
        default:
            break;
        }
    }

    if (src->type == to) {
        // Copy two or four bytes at a time, depending on block size.
        // For quantized types, we scale by block size/type size. But
        // this path is also used for bf16->bf16 for example, where the
        // type size must be exactly 2 or 4.
        GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
        if ((ggml_type_size(src->type) % 4) == 0) {
            if (contig) {
                return ctx->device->pipeline_contig_cpy_f32_f32;
            } else {
                return ctx->device->pipeline_cpy_f32_f32;
            }
        } else {
            if (contig) {
                return ctx->device->pipeline_contig_cpy_f16_f16;
            } else {
                return ctx->device->pipeline_cpy_f16_f16;
            }
        }
    }

    std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
    GGML_ABORT("fatal error");
}
 6964
 6965static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, const vk_subbuffer & in, const vk_subbuffer & out) {
 6966    VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
 6967    std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
 6968    const size_t tensor_type_size = ggml_type_size(tensor->type);
 6969
 6970    const uint32_t ne = ggml_nelements(tensor);
 6971    std::array<uint32_t, 3> elements;
 6972
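         // Spread the copy over up to three dispatch dimensions, capping x and y at
         // 512 workgroups each (512*512 == 262144) to stay within device limits.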
 6973    if (ne > 262144) {
 6974        elements = { 512, 512, CEIL_DIV(ne, 262144) };
 6975    } else if (ne > 512) {
 6976        elements = { 512, CEIL_DIV(ne, 512), 1 };
 6977    } else {
 6978        elements = { ne, 1, 1 };
 6979    }
 6980
 6981    vk_op_unary_push_constants pc = {
 6982        (uint32_t)ne,
 6983        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
 6984        (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3],                       1                   , (uint32_t)tensor->ne[0]                   , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
 6985        0,
 6986        0.0f, 0.0f,
 6987        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 6988    };
 6989    init_pushconst_fastdiv(pc);
 6990    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
 6991    ggml_vk_sync_buffers(ctx, subctx);
 6992}
 6993
 6994static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
 6995    switch (type) {
 6996        case GGML_TYPE_Q8_1:
 6997            return ctx->device->pipeline_quantize_q8_1_x4;
 6998        default:
 6999            std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
 7000            GGML_ABORT("fatal error");
 7001    }
 7002}
 7003
 7004static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, const vk_subbuffer & in, const vk_subbuffer & out, uint32_t ne) {
 7005    VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
 7006
 7007    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 7008
 7009    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
 7010    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
 7011    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
 7012    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
 7013
 7014    const vk_quantize_q8_1_push_constants pc = {
 7015        ne,
 7016        num_blocks,
 7017    };
 7018
 7019    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });
 7020    ggml_vk_sync_buffers(ctx, subctx);
 7021}
 7022
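     // If the pipeline has a 64-bit-indexing variant in its variant list, return it;
     // otherwise return the pipeline unchanged. Used when a tensor is larger than
     // maxStorageBufferRange and 32-bit offsets could overflow.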
 7023static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
 7024    GGML_UNUSED(ctx);
 7025#if defined(VK_EXT_shader_64bit_indexing)
 7026    vk_pipeline *ptr = &pipeline;
 7027    while (*ptr) {
 7028        if ((*ptr)->is_64b_indexing) {
 7029            return *ptr;
 7030        }
 7031        ptr = &(*ptr)->next;
 7032    }
 7033#endif
 7034    return pipeline;
 7035}
 7036
 7037static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
 7038    VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 7039    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 7040    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 7041    std::cerr << "))");
 7042    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
 7043    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 7044
 7045    const uint64_t ne00 = src0->ne[0];
 7046    const uint64_t ne01 = src0->ne[1];
 7047    const uint64_t ne02 = src0->ne[2];
 7048    const uint64_t ne03 = src0->ne[3];
 7049
 7050    const uint64_t ne10 = src1->ne[0];
 7051    const uint64_t ne11 = src1->ne[1];
 7052    const uint64_t ne12 = src1->ne[2];
 7053    const uint64_t ne13 = src1->ne[3];
 7054
 7055    const uint64_t ne21 = dst->ne[1];
 7056    const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
 7057    const uint32_t stride_batch_d = stride_d*ne21;
 7058
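         // Broadcast ratios: src1's batch dims repeat over src0's (ne12 == r2*ne02, ne13 == r3*ne03)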
 7059    const uint64_t r2 = ne12 / ne02;
 7060    const uint64_t r3 = ne13 / ne03;
 7061
 7062    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
 7063    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
 7064    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
 7065
 7066    vk_buffer d_Qx = nullptr;
 7067    size_t qx_buf_offset = 0;
 7068    vk_buffer d_Qy = nullptr;
 7069    size_t qy_buf_offset = 0;
 7070
 7071    bool src0_uma = false;
 7072    bool src1_uma = false;
 7073
 7074    if (ctx->device->uma) {
 7075        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
 7076        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
 7077        src0_uma = d_Qx != nullptr;
 7078        src1_uma = d_Qy != nullptr;
 7079    }
 7080
 7081    // Reformat and convert to fp16 if non-contiguous; also convert f32 inputs when using coopmat2, for better performance
 7082    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
 7083                              !ggml_vk_dim01_contiguous(src0);
 7084    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
 7085                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
 7086                              !ggml_vk_dim01_contiguous(src1);
 7087
 7088    // If src0 is BF16, try to use a BF16 x BF16 multiply
 7089    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
 7090
 7091    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 7092
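         // Quantizing src1 to Q8_1 enables the integer-dot-product (MMQ) path; the
         // q8_1_x4 quantize shader processes elements in groups of 4, hence the
         // multiple-of-4 requirement.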
 7093    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
 7094
 7095    // Check for mmq first
 7096    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
 7097
 7098    if (mmp == nullptr) {
 7099        // Fall back to f16 dequant mul mat
 7100        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
 7101        quantize_y = false;
 7102    }
 7103
 7104    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
 7105    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
 7106
 7107    if (qx_needs_dequant) {
 7108        // Fall back to dequant + f16 mulmat
 7109        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
 7110    }
 7111
 7112    // Not implemented
 7113    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 7114
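         // kpad: k rounded up to the selected pipeline's alignment. The "aligned"
         // matmul variants assume k is already a multiple of the tile alignment, so
         // they are only used when no padding is needed and m/n are larger than 8.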
 7115    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
 7116    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
 7117
 7118    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 7119
 7120    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 7121        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
 7122    }
 7123
 7124    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
 7125    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
 7126    const uint64_t x_ne = ggml_nelements(src0);
 7128    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
 7129    const uint64_t d_ne = ggml_nelements(dst);
 7130
 7131    const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
 7132
 7133    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
 7134    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
 7135    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
         // 128 elements per Q8_1 x4 block, hence the alignment to 128
 7136    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
 7137    const uint64_t d_sz = sizeof(float) * d_ne;
 7138
 7139    vk_pipeline to_fp16_vk_0 = nullptr;
 7140    vk_pipeline to_fp16_vk_1 = nullptr;
 7141    vk_pipeline to_q8_1 = nullptr;
 7142
 7143    if (x_non_contig) {
 7144        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
 7145    } else {
 7146        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
 7147    }
 7148    if (y_non_contig) {
 7149        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
 7150    } else {
 7151        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 7152    }
 7153    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
 7154    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 7155
 7156    if (quantize_y) {
 7157        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 7158    }
 7159
 7160    {
 7161        const uint64_t split_k_size = split_k > 1 ? d_sz * split_k : 0;
 7162        if (
 7163                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
 7164                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
 7165                (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
 7166            GGML_ABORT("Requested preallocation size is too large");
 7167        }
 7168        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
 7169            ctx->prealloc_size_x = x_sz;
 7170            ggml_vk_preallocate_buffers(ctx, subctx);
 7171        }
 7172        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
 7173            ctx->prealloc_size_y = y_sz;
 7174            ggml_vk_preallocate_buffers(ctx, subctx);
 7175        }
 7176        if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
 7177            ctx->prealloc_size_split_k = split_k_size;
 7178            ggml_vk_preallocate_buffers(ctx, subctx);
 7179        }
 7180
 7181        // Request descriptor sets
 7182        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 7183        if (qx_needs_dequant) {
 7184            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
 7185        }
 7186        if (qy_needs_dequant) {
 7187            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
 7188        }
 7189        if (quantize_y) {
 7190            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
 7191        }
 7192        if (split_k > 1) {
 7193            ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
 7194        }
 7195    }
 7196
 7197    vk_buffer d_D = dst_buf_ctx->dev_buffer;
 7198    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
 7199    GGML_ASSERT(d_D != nullptr);
 7200    GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);
 7201    vk_buffer d_X;
 7202    uint64_t x_buf_offset = 0;
 7203    vk_buffer d_Y;
 7204    uint64_t y_buf_offset = 0;
 7205    if (!src0_uma) {
 7206        d_Qx = src0_buf_ctx->dev_buffer;
 7207        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
 7208        GGML_ASSERT(d_Qx != nullptr);
 7209    }
 7210    if (!src1_uma) {
 7211        d_Qy = src1_buf_ctx->dev_buffer;
 7212        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
 7213        GGML_ASSERT(d_Qy != nullptr);
 7214    }
 7215    if (qx_needs_dequant) {
 7216        d_X = ctx->prealloc_x;
 7217        GGML_ASSERT(d_X->size >= x_sz);
 7218    } else {
 7219        d_X = d_Qx;
 7220        x_buf_offset = qx_buf_offset;
 7221        GGML_ASSERT(qx_sz == x_sz);
 7222    }
 7223    if (qy_needs_dequant) {
 7224        d_Y = ctx->prealloc_y;
 7225        GGML_ASSERT(d_Y->size >= y_sz);
 7226    } else if (quantize_y) {
 7227        d_Y = ctx->prealloc_y;
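             // prealloc_y is filled in whole q8_1_x4 blocks: 4 q8_1 blocks (36 bytes each) = 144 bytes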
 7228        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
 7229    } else {
 7230        d_Y = d_Qy;
 7231        y_buf_offset = qy_buf_offset;
 7232        GGML_ASSERT(qy_sz == y_sz);
 7233    }
 7234
 7235    if (x_non_contig || qx_needs_dequant) {
 7236        if (ctx->prealloc_x_need_sync) {
 7237            ggml_vk_sync_buffers(ctx, subctx);
 7238        }
 7239    }
 7240
 7241    if (x_non_contig) {
 7242        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
 7243    } else if (qx_needs_dequant) {
 7244        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
 7245        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)(x_ne), 1, 1});
 7246        ggml_vk_sync_buffers(ctx, subctx);
 7247    }
 7248    if (y_non_contig) {
 7249        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
 7250            ctx->prealloc_y_last_tensor_used != src1) {
 7251            if (ctx->prealloc_y_need_sync) {
 7252                ggml_vk_sync_buffers(ctx, subctx);
 7253            }
 7254            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
 7255            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
 7256            ctx->prealloc_y_last_tensor_used = src1;
 7257        }
 7258    }
 7259    if (quantize_y) {
 7260        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
 7261            ctx->prealloc_y_last_tensor_used != src1) {
 7262            if (ctx->prealloc_y_need_sync) {
 7263                ggml_vk_sync_buffers(ctx, subctx);
 7264            }
 7265            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
 7266            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
 7267            ctx->prealloc_y_last_tensor_used = src1;
 7268        }
 7269    }
 7270
 7271    uint32_t stride_batch_x = ne00*ne01;
 7272    uint32_t stride_batch_y = ne10*ne11;
 7273
 7274    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
 7275        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
 7276    }
 7277
 7278    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
 7279        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
 7280    }
 7281
 7282    // compute
 7283    ggml_vk_matmul(
 7284        ctx, subctx, pipeline,
 7285        { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
 7286        ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k },
 7287        ne01, ne11, ne10,
 7288        ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
 7289        split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
 7290    );  // NOLINT
 7291
 7292    if (x_non_contig || qx_needs_dequant) {
 7293        ctx->prealloc_x_need_sync = true;
 7294    }
 7295    if (y_non_contig || quantize_y) {
 7296        ctx->prealloc_y_need_sync = true;
 7297    }
 7298}
 7299
 7300// Device tuning
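     // Decide whether the quantized (MMVQ) mat-vec path should be used.
     // mmvq_mode == 1 forces it on, mmvq_mode == -1 forces it off; otherwise
     // fall through to the per-vendor heuristics below.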
 7301static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {
 7302    if (device->mmvq_mode == 1) {
 7303        return true;
 7304    } else if (device->mmvq_mode == -1) {
 7305        return false;
 7306    }
 7307
 7308    // General performance issue with q3_k and q6_k due to 2-byte alignment
 7309    if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
 7310        return false;
 7311    }
 7312
 7313    // MMVQ is generally good for batches
 7314    if (n > 1) {
 7315        return true;
 7316    }
 7317
 7318    // Quantization overhead is not worth it for small k
 7319    switch (device->vendor_id) {
 7320    case VK_VENDOR_ID_NVIDIA:
 7321        if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
 7322            return true;
 7323        }
 7324
 7325        if (k <= 4096) {
 7326            return false;
 7327        }
 7328
 7329        switch (src0_type) {
 7330        case GGML_TYPE_MXFP4:
 7331        case GGML_TYPE_Q8_0:
 7332            return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
 7333        default:
 7334            return true;
 7335        }
 7336    case VK_VENDOR_ID_AMD:
 7337        if (k < 2048) {
 7338            return false;
 7339        }
 7340
 7341        switch (src0_type) {
 7342        case GGML_TYPE_Q8_0:
 7343            return device->architecture == vk_device_architecture::AMD_GCN;
 7344        default:
 7345            return true;
 7346        }
 7347    case VK_VENDOR_ID_INTEL:
 7348        if (k < 2048) {
 7349            return false;
 7350        }
 7351
 7352        switch (src0_type) {
 7353        // From tests on A770 Linux, may need more tuning
 7354        case GGML_TYPE_Q4_0:
 7355        case GGML_TYPE_Q5_1:
 7356            return false;
 7357        default:
 7358            return true;
 7359        }
 7360    default:
 7361        return true;
 7362    }
 7363
 7364    GGML_UNUSED(m);
 7365}
 7366
 7367static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 7368    ggml_tensor * dst = cgraph->nodes[node_idx];
 7369    const ggml_tensor * src0 = dst->src[0];
 7370    const ggml_tensor * src1 = dst->src[1];
 7371
 7372    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 7373    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 7374    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 7375    std::cerr << "))");
 7376    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
 7377    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 7378
 7379    const uint64_t ne00 = src0->ne[0];
 7380    const uint64_t ne01 = src0->ne[1];
 7381    const uint64_t ne02 = src0->ne[2];
 7382    const uint64_t ne03 = src0->ne[3];
 7383
 7384    const uint64_t ne10 = src1->ne[0];
 7385    const uint64_t ne11 = src1->ne[1];
 7386    const uint64_t ne12 = src1->ne[2];
 7387    const uint64_t ne13 = src1->ne[3];
 7388
 7389    const uint64_t ne20 = dst->ne[0];
 7390    const uint64_t ne21 = dst->ne[1];
 7391    // const uint64_t ne22 = dst->ne[2];
 7392    // const uint64_t ne23 = dst->ne[3];
 7393
 7394    const uint64_t r2 = ne12 / ne02;
 7395    const uint64_t r3 = ne13 / ne03;
 7396
 7397    // batch_n indicates that we need to compute a few vector results, and this assumes
 7398    // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
 7399    GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
 7400    bool batch_n = ne11 > 1;
 7401
 7402    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
 7403    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 7404
 7405    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
 7406    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
 7407
 7408    vk_pipeline to_fp16_vk_0 = nullptr;
 7409    vk_pipeline to_fp16_vk_1 = nullptr;
 7410    if (x_non_contig) {
 7411        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
 7412    }
 7413    if (y_non_contig) {
 7414        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
 7415    } else {
 7416        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 7417    }
 7418
 7419    // Check for mmvq first
 7420    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
 7421    vk_pipeline to_q8_1 = nullptr;
 7422
 7423    if (dmmv == nullptr) {
 7424        // Fall back to f16 dequant mul mat
 7425        dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
 7426        quantize_y = false;
 7427    }
 7428
 7429    if (quantize_y) {
 7430        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 7431    }
 7432
 7433    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 7434        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
 7435    }
 7436
 7437    const bool qx_needs_dequant = x_non_contig;
 7438    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 7439
 7440    // Not implemented
 7441    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 7442
 7443    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
 7444    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 7445    GGML_ASSERT(dmmv != nullptr);
 7446
 7447    const uint64_t x_ne = ggml_nelements(src0);
 7448    const uint64_t y_ne = ggml_nelements(src1);
 7449
 7450    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
 7451    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
 7452    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
 7453                         (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
 7454
 7455    {
 7456        if (
 7457                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
 7458                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
 7459            GGML_ABORT("Requested preallocation size is too large");
 7460        }
 7461        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
 7462            ctx->prealloc_size_x = x_sz;
 7463            ggml_vk_preallocate_buffers(ctx, subctx);
 7464        }
 7465        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
 7466            ctx->prealloc_size_y = y_sz;
 7467            ggml_vk_preallocate_buffers(ctx, subctx);
 7468        }
 7469
 7470        // Request descriptor sets
 7471        if (qx_needs_dequant) {
 7472            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
 7473        }
 7474        if (qy_needs_dequant) {
 7475            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
 7476        }
 7477        if (quantize_y) {
 7478            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
 7479        }
 7480        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
 7481    }
 7482
 7483    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 7484    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
 7485    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
 7486    vk_subbuffer d_X, d_Y;
 7487
 7488    if (qx_needs_dequant) {
 7489        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
 7490    } else {
 7491        d_X = d_Qx;
 7492        GGML_ASSERT(qx_sz == x_sz);
 7493    }
 7494    if (qy_needs_dequant || quantize_y) {
 7495        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
 7496    } else {
 7497        d_Y = d_Qy;
 7498    }
 7499
 7500    if (x_non_contig) {
 7501        if (ctx->prealloc_x_need_sync) {
 7502            ggml_vk_sync_buffers(ctx, subctx);
 7503        }
 7504
 7505        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
 7506        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
 7507    }
 7508    if (y_non_contig) {
 7509        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
 7510        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
 7511            ctx->prealloc_y_last_tensor_used != src1) {
 7512            if (ctx->prealloc_y_need_sync) {
 7513                ggml_vk_sync_buffers(ctx, subctx);
 7514            }
 7515            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
 7516            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
 7517            ctx->prealloc_y_last_tensor_used = src1;
 7518        }
 7519    }
 7520    if (quantize_y) {
 7521        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
 7522            ctx->prealloc_y_last_tensor_used != src1) {
 7523            if (ctx->prealloc_y_need_sync) {
 7524                ggml_vk_sync_buffers(ctx, subctx);
 7525            }
 7526            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
 7527            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
 7528            ctx->prealloc_y_last_tensor_used = src1;
 7529        }
 7530    }
 7531
 7532    // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
 7533    uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
 7534    uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
 7535    uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
 7536
 7537    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
 7538        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
 7539    }
 7540
 7541    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
 7542        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
 7543    }
 7544
 7545    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
 7546
 7547    uint32_t groups_x = ne01;
 7548    uint32_t groups_z = 1;
 7549
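         // If the row count exceeds the device's max workgroup count in x, split the
         // rows across the z dimension as well.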
 7550    if (ne01 > max_groups_x) {
 7551        groups_z = 64;
 7552        groups_x = CEIL_DIV(groups_x, groups_z);
 7553    }
 7554
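         // When trailing ADD ops were fused into this mat-vec, bind their bias
         // operands as the extra F0/F1 inputs; unused slots alias d_D so the
         // descriptor bindings remain valid.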
 7555    uint32_t fusion_flags = 0;
 7556
 7557    vk_subbuffer d_F0 = d_D;
 7558    if (ctx->num_additional_fused_ops > 0) {
 7559        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
 7560        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 7561
 7562        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
 7563        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
 7564    }
 7565
 7566    vk_subbuffer d_F1 = d_D;
 7567    if (ctx->num_additional_fused_ops == 2) {
 7568        const ggml_tensor * add = cgraph->nodes[node_idx + 2];
 7569        const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
 7570
 7571        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
 7572        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
 7573    }
 7574
 7575    // compute
 7576    const vk_mat_vec_push_constants pc = {
 7577        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
 7578        stride_batch_x, stride_batch_y, stride_batch_d,
 7579        fusion_flags,
 7580        (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
 7581    };
 7582    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
 7583                              {
 7584                                d_X,
 7585                                d_Y,
 7586                                d_D,
 7587                                d_F0,
 7588                                d_F1,
 7589                              },
 7590                              pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 7591
 7592    if (x_non_contig) {
 7593        ctx->prealloc_x_need_sync = true;
 7594    }
 7595    if (y_non_contig || quantize_y) {
 7596        ctx->prealloc_y_need_sync = true;
 7597    }
 7598}
 7599
 7600static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 7601    ggml_tensor * dst = cgraph->nodes[node_idx];
 7602    const ggml_tensor * src0 = dst->src[0];
 7603    const ggml_tensor * src1 = dst->src[1];
 7604    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_p021_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 7605    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 7606    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 7607    std::cerr << "))");
 7608    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
 7609    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]);  // NOLINT
 7610    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]);  // NOLINT
 7611    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 7612    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 7613
 7614    const uint64_t ne00 = src0->ne[0];
 7615    const uint64_t ne01 = src0->ne[1];
 7616    const uint64_t ne02 = src0->ne[2];
 7617    // const uint64_t ne03 = src0->ne[3];
 7618
 7619    //const uint64_t ne10 = src1->ne[0];
 7620    const uint64_t ne11 = src1->ne[1];
 7621    const uint64_t ne12 = src1->ne[2];
 7622    // const uint64_t ne13 = src1->ne[3];
 7623
 7624    GGML_ASSERT(ne11 == 1);
 7625
 7626    // With grouped query attention there are > 1 Q matrices per K, V matrix.
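         // Pipeline variants exist for gqa_ratio 1..8; anything outside that range
         // falls back to the gqa_ratio == 1 variant.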
 7627    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
 7628    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
 7629        gqa_ratio = 1;
 7630    }
 7631
 7632    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
 7633
 7634    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 7635        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
 7636    }
 7637
 7638    {
 7639        // Request descriptor sets
 7640        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 7641    }
 7642
 7643    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
 7644    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
 7645    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
 7646
 7647    vk_subbuffer d_F0 = d_D;
 7648
 7649    uint32_t fusion_flags = 0;
 7650
 7651    if (ctx->num_additional_fused_ops > 0) {
 7652        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
 7653        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 7654
 7655        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
 7656        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
 7657    }
 7658
 7659    vk_subbuffer d_F1 = d_D;
 7660    if (ctx->num_additional_fused_ops > 1) {
 7661        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
 7662
 7663        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
 7664        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
 7665    }
 7666
 7667    // compute
 7668
 7669    vk_mat_vec_p021_push_constants pc = {
 7670        (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
 7671        0, 0, fusion_flags
 7672    };
 7673
 7674    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 7675
 7676    uint32_t workgroups_z = (uint32_t)ne12;
 7677    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
 7678    if (gqa_ratio > 1) {
 7679        workgroups_z /= gqa_ratio;
 7680    }
 7681
 7682    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
 7683        {
 7684            d_Qx,
 7685            d_Qy,
 7686            d_D,
 7687            d_F0,
 7688            d_F1,
 7689        }, pc, { 1, (uint32_t)ne01, workgroups_z });
 7690}
 7691
 7692static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 7693    ggml_tensor * dst = cgraph->nodes[node_idx];
 7694    const ggml_tensor * src0 = dst->src[0];
 7695    const ggml_tensor * src1 = dst->src[1];
 7696    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 7697    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 7698    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 7699    std::cerr << "))");
 7700    GGML_ASSERT(!ggml_is_transposed(src0));
 7701    GGML_ASSERT(!ggml_is_transposed(src1));
 7702    GGML_ASSERT(!ggml_is_permuted(src0));
 7703    GGML_ASSERT(src0->type == GGML_TYPE_F16);
 7704    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 7705
 7706    const uint64_t ne00 = src0->ne[0];
 7707    const uint64_t ne01 = src0->ne[1];
 7708    const uint64_t ne02 = src0->ne[2];
 7709    const uint64_t ne03 = src0->ne[3];
 7710
 7711    const uint64_t nb01 = src0->nb[1];
 7712    const uint64_t nb02 = src0->nb[2];
 7713
 7714    const uint64_t nb12 = src1->nb[2];
 7715
 7716    // const uint64_t ne10 = src1->ne[0];
 7717    const uint64_t ne11 = src1->ne[1];
 7718    const uint64_t ne12 = src1->ne[2];
 7719    // const uint64_t ne13 = src1->ne[3];
 7720
 7721    const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
 7722    const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
 7723    const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
 7724
 7725    GGML_ASSERT(ne11 == 1);
 7726    GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
 7727
 7728    const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
 7729    const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
 7730    const uint32_t channel_stride_y = nb12 / sizeof(float);
 7731
 7732    vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
 7733    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 7734        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
 7735    }
 7736
 7737    {
 7738        // Request descriptor sets
 7739        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 7740    }
 7741
 7742    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
 7743    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
 7744    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
 7745    vk_subbuffer d_F0 = d_D;
 7746
 7747    uint32_t fusion_flags = 0;
 7748
 7749    if (ctx->num_additional_fused_ops > 0) {
 7750        const ggml_tensor * add = cgraph->nodes[node_idx + 1];
 7751        const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 7752
 7753        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
 7754        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
 7755    }
 7756
 7757    vk_subbuffer d_F1 = d_D;
 7758    if (ctx->num_additional_fused_ops > 1) {
 7759        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
 7760
 7761        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
 7762        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
 7763    }
 7764
 7765    // compute
 7766    vk_mat_vec_nc_push_constants pc = {
 7767        (uint32_t)ne00, (uint32_t)ne01,
 7768        row_stride_x, channel_stride_x, channel_stride_y,
 7769        (uint32_t)(ne12 / ne02), (uint32_t)ne12,
 7770        0, 0,
 7771        nb03, nb13, nb23, fusion_flags
 7772    };
 7773
 7774    init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 7775
 7776    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
 7777        {
 7778            d_Qx,
 7779            d_Qy,
 7780            d_D,
 7781            d_F0,
 7782            d_F1,
 7783        }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
 7784}
 7785
 7786static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 7787    ggml_tensor * dst = cgraph->nodes[node_idx];
 7788    ggml_tensor * src0 = dst->src[0];
 7789    ggml_tensor * src1 = dst->src[1];
 7790    VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
 7791
 7792    // Handle a huge A matrix by splitting the M dimension. This works well for convolution use cases
 7793    // where the M dimension is very large.
 7794    // split_k doesn't work with M splitting.
 7795    // This only supports batch size == 1.
 7796    const size_t nbytes = ggml_nbytes(src0);
 7797    const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
 7798    if (needs_split) {
 7799        // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
 7800        const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
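             // illustrative example: with a 128 MiB maxStorageBufferRange and 8 KiB rows,
             // M_split = 134217728 / (2 * 8192) = 8192 rows per chunk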
 7801        uint32_t m_offset = 0;
 7802        while (m_offset < dst->ne[0]) {
 7803            const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
 7804            ggml_tensor dst2 = *dst;
 7805            ggml_tensor src02 = *src0;
 7806
 7807            dst2.view_src = dst->view_src ? dst->view_src : dst;
 7808            src02.view_src = src0->view_src ? src0->view_src : src0;
 7809
 7810            dst2.view_offs += m_offset * dst->nb[0];
 7811            src02.view_offs += m_offset * src0->nb[1];
 7812            dst2.ne[0] = cur_M_size;
 7813            src02.ne[1] = cur_M_size;
 7814
 7815            ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);
 7816
 7817            m_offset += cur_M_size;
 7818        }
 7819    } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
 7820        // detect 0213 permutation, and batch size of 1
 7821        src0->nb[0] <= src0->nb[2] &&
 7822        src0->nb[2] <= src0->nb[1] &&
 7823        src0->nb[1] <= src0->nb[3] &&
 7824        src1->nb[0] <= src1->nb[2] &&
 7825        src1->nb[2] <= src1->nb[1] &&
 7826        src1->nb[1] <= src1->nb[3] &&
 7827        src0->ne[3] == 1 &&
 7828        src1->ne[3] == 1) {
 7829        ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
 7830    } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
 7831               !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
 7832        ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
 7833    // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size
 7834    // (up to mul_mat_vec_max_cols) when ne12 and ne13 are one.
 7835    } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
 7836               (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
 7837        ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);
 7838    } else {
 7839        ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);
 7840    }
 7841}
 7842
 7843static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
 7844    VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 7845    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 7846    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
 7847    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
 7848    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 7849    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 7850
 7851    const uint64_t ne00 = src0->ne[0];
 7852    const uint64_t ne01 = src0->ne[1];
 7853    const uint64_t ne02 = src0->ne[2];
 7854    // const uint64_t ne03 = src0->ne[3];
 7855
 7856    const uint64_t ne10 = src1->ne[0];
 7857    const uint64_t ne11 = src1->ne[1];
 7858    const uint64_t ne12 = src1->ne[2];
 7859    const uint64_t ne13 = src1->ne[3];
 7860
 7861    const uint64_t nei0 = ids->ne[0];
 7862    const uint64_t nei1 = ids->ne[1];
 7863
 7864    const uint32_t nbi0 = ids->nb[0];
 7865    const uint32_t nbi1 = ids->nb[1];
 7866    const uint32_t nbi2 = ids->nb[2];
 7867
 7868    const uint64_t ne20 = dst->ne[0];
 7869    const uint64_t ne21 = dst->ne[1];
 7870    // const uint64_t ne22 = dst->ne[2];
 7871    // const uint64_t ne23 = dst->ne[3];
 7872
 7873    const uint64_t n_as = ne02;
 7874
 7875    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
 7876    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
 7877    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
 7878    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
 7879
 7880    vk_buffer d_Qx = nullptr;
 7881    size_t qx_buf_offset = 0;
 7882    vk_buffer d_Qy = nullptr;
 7883    size_t qy_buf_offset = 0;
 7884    vk_buffer d_ids = nullptr;
 7885    size_t ids_buf_offset = 0;
 7886
 7887    bool src0_uma = false;
 7888    bool src1_uma = false;
 7889    bool ids_uma = false;
 7890
 7891    if (ctx->device->uma) {
 7892        ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
 7893        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
 7894        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
 7895        src0_uma = d_Qx != nullptr;
 7896        src1_uma = d_Qy != nullptr;
 7897        ids_uma = d_ids != nullptr;
 7898    }
 7899
 7900    // Reformat and convert to fp16 if non-contiguous; also convert f32 inputs when using coopmat2, for better performance
 7901    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
 7902                              !ggml_vk_dim01_contiguous(src0);
 7903    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
 7904                              (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
 7905                              !ggml_vk_dim01_contiguous(src1);
 7906
 7907    // If src0 is BF16, try to use a BF16 x BF16 multiply
 7908    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
 7909
 7910    const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 7911
 7912    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
 7913
 7914    // Check for mmq first
 7915    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
 7916
 7917    if (mmp == nullptr) {
 7918        // Fall back to f16 dequant mul mat
 7919        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
 7920        quantize_y = false;
 7921    }
 7922
 7923    const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
 7924    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
 7925
 7926    if (qx_needs_dequant) {
 7927        // Fall back to dequant + f16 mulmat
 7928        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
 7929    }
 7930
 7931    // Not implemented
 7932    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 7933
 7934    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
 7935    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
 7936
 7937    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 7938
 7939    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 7940        pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
 7941    }
 7942    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
 7943    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
 7944    const uint64_t x_ne = ggml_nelements(src0);
 7945    const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
 7946    const uint64_t d_ne = ggml_nelements(dst);
 7947
 7948    const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
 7949    const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
 7950    const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
 7951    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
 7952    const uint64_t ids_sz = nbi2;
 7953    const uint64_t d_sz = sizeof(float) * d_ne;
 7954
 7955    vk_pipeline to_fp16_vk_0 = nullptr;
 7956    vk_pipeline to_fp16_vk_1 = nullptr;
 7957    vk_pipeline to_q8_1 = nullptr;
 7958
 7959    if (x_non_contig) {
 7960        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
 7961    } else {
 7962        to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
 7963    }
 7964    if (y_non_contig) {
 7965        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
 7966    } else {
 7967        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 7968    }
 7969    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
 7970    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 7971
 7972    if (quantize_y) {
 7973        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 7974    }
 7975    vk_pipeline count_experts = ctx->device->pipeline_count_experts;
 7976
 7977    uint32_t expert_count_size = sizeof(uint32_t) * n_as;
 7978
 7979    {
 7980        if (
 7981                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
 7982                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
 7983            GGML_ABORT("Requested preallocation size is too large");
 7984        }
 7985        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
 7986            ctx->prealloc_size_x = x_sz;
 7987            ggml_vk_preallocate_buffers(ctx, subctx);
 7988        }
 7989        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
 7990            ctx->prealloc_size_y = y_sz;
 7991            ggml_vk_preallocate_buffers(ctx, subctx);
 7992        }
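             // the split_k staging buffer doubles as scratch space for the per-expert counts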
 7993        if (ctx->prealloc_size_split_k < expert_count_size) {
 7994            ctx->prealloc_size_split_k = expert_count_size;
 7995            ggml_vk_preallocate_buffers(ctx, subctx);
 7996        }
 7997
 7998        // Request descriptor sets
 7999        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 8000        if (qx_needs_dequant) {
 8001            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
 8002        }
 8003        if (qy_needs_dequant) {
 8004            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
 8005        }
 8006        if (quantize_y) {
 8007            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
 8008        }
 8009        ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
 8010    }
 8011
 8012    vk_buffer d_D = dst_buf_ctx->dev_buffer;
 8013    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
 8014    GGML_ASSERT(d_D != nullptr);
 8015    vk_buffer d_X;
 8016    uint64_t x_buf_offset = 0;
 8017    vk_buffer d_Y;
 8018    uint64_t y_buf_offset = 0;
 8019    if (!src0_uma) {
 8020        d_Qx = src0_buf_ctx->dev_buffer;
 8021        qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
 8022        GGML_ASSERT(d_Qx != nullptr);
 8023    }
 8024    if (!src1_uma) {
 8025        d_Qy = src1_buf_ctx->dev_buffer;
 8026        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
 8027        GGML_ASSERT(d_Qy != nullptr);
 8028    }
 8029    if (!ids_uma) {
 8030        d_ids = ids_buf_ctx->dev_buffer;
 8031        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
 8032        GGML_ASSERT(d_ids != nullptr);
 8033    }
 8034    if (qx_needs_dequant) {
 8035        d_X = ctx->prealloc_x;
 8036        GGML_ASSERT(d_X->size >= x_sz);
 8037    } else {
 8038        d_X = d_Qx;
 8039        x_buf_offset = qx_buf_offset;
 8040        GGML_ASSERT(qx_sz == x_sz);
 8041    }
 8042    if (qy_needs_dequant) {
 8043        d_Y = ctx->prealloc_y;
 8044        GGML_ASSERT(d_Y->size >= y_sz);
 8045    } else if (quantize_y) {
 8046        d_Y = ctx->prealloc_y;
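             // round up to a whole number of 144-byte q8_1_x4 blocks (4 x 36-byte q8_1 blocks)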
 8047        GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
 8048    } else {
 8049        d_Y = d_Qy;
 8050        y_buf_offset = qy_buf_offset;
 8051        GGML_ASSERT(qy_sz == y_sz);
 8052    }
 8053
 8054    if (x_non_contig || qx_needs_dequant) {
 8055        if (ctx->prealloc_x_need_sync) {
 8056            ggml_vk_sync_buffers(ctx, subctx);
 8057        }
 8058    }
 8059    // Count how many times each expert is used
 8060    vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
 8061    if (ctx->prealloc_split_k_need_sync) {
 8062        ggml_vk_sync_buffers(ctx, subctx);
 8063    }
 8064    {
 8065        const std::vector<uint32_t> pc = { (uint32_t)nei0,
 8066                                           (uint32_t)nei1,
 8067                                           (uint32_t)(nbi0 / ggml_type_size(ids->type)),
 8068                                           (uint32_t)(nbi1 / ggml_type_size(ids->type)),
 8069                                           (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
 8070        ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
 8071            { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
 8072    }
 8073
 8074    if (x_non_contig) {
 8075        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
 8076    } else if (qx_needs_dequant) {
 8077        const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
 8078        ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
 8079            { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
 8080    }
 8081    if (y_non_contig) {
 8082        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
 8083            ctx->prealloc_y_last_tensor_used != src1) {
 8084            if (ctx->prealloc_y_need_sync) {
 8085                ggml_vk_sync_buffers(ctx, subctx);
 8086            }
 8087            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
 8088            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
 8089            ctx->prealloc_y_last_tensor_used = src1;
 8090        }
 8091    }
 8092    if (quantize_y) {
 8093        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
 8094            ctx->prealloc_y_last_tensor_used != src1) {
 8095            if (ctx->prealloc_y_need_sync) {
 8096                ggml_vk_sync_buffers(ctx, subctx);
 8097            }
 8098            ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
 8099            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
 8100            ctx->prealloc_y_last_tensor_used = src1;
 8101        }
 8102    }
 8103    ggml_vk_sync_buffers(ctx, subctx);
 8104
 8105    uint32_t stride_batch_x = ne00*ne01;
 8106    uint32_t stride_batch_y = ne10*ne11;
 8107
 8108    if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
 8109        stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
 8110    }
 8111
 8112    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
 8113        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
 8114    }
 8115
 8116    // compute
 8117    ggml_vk_matmul_id(
 8118        ctx, subctx, pipeline,
 8119        { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
 8120        { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
 8121        ne01, ne21, ne10, ne10, ne10, ne01,
 8122        stride_batch_x, stride_batch_y, ne20*ne21,
 8123        n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
 8124    );  // NOLINT
 8125
 8126    if (x_non_contig || qx_needs_dequant) {
 8127        ctx->prealloc_x_need_sync = true;
 8128    }
 8129    if (y_non_contig || quantize_y) {
 8130        ctx->prealloc_y_need_sync = true;
 8131    }
 8132    ctx->prealloc_split_k_need_sync = true;
 8133}
 8134
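// Multiply an expert-selected matrix (MUL_MAT_ID) against a small batch of vectors
// using the dequant mul_mat_vec shaders; one dispatch per column of the ids tensor.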
 8135static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 8136    ggml_tensor * dst = cgraph->nodes[node_idx];
 8137    ggml_tensor * src0 = dst->src[0];
 8138    ggml_tensor * src1 = dst->src[1];
 8139    ggml_tensor * ids = dst->src[2];
 8140    VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 8141    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
 8142    std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
 8143    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 8144    std::cerr << "))");
 8145    GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16);  // NOLINT
 8146    GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);  // NOLINT
 8147    GGML_ASSERT(ids->type == GGML_TYPE_I32);
 8148
 8149    const uint64_t ne00 = src0->ne[0];
 8150    const uint64_t ne01 = src0->ne[1];
 8151    // const uint64_t ne02 = src0->ne[2];
 8152    // const uint64_t ne03 = src0->ne[3];
 8153
 8154    const uint64_t ne10 = src1->ne[0];
 8155    const uint64_t ne11 = src1->ne[1];
 8156    const uint64_t ne12 = src1->ne[2];
 8157    // const uint64_t ne13 = src1->ne[3];
 8158
 8159    const uint64_t nei0 = ids->ne[0];
 8160    const uint64_t nei1 = ids->ne[1];
 8161    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
 8162
 8163    const uint64_t ne20 = dst->ne[0];
 8164    const uint64_t ne21 = dst->ne[1];
 8165    // const uint64_t ne22 = dst->ne[2];
 8166    // const uint64_t ne23 = dst->ne[3];
 8167
 8168    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
 8169    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
 8170
 8171    const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
 8172    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
 8173
 8174    vk_pipeline to_fp16_vk_0 = nullptr;
 8175    vk_pipeline to_fp16_vk_1 = nullptr;
 8176    if (x_non_contig) {
 8177        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
 8178    }
 8179    if (y_non_contig) {
 8180        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
 8181    } else {
 8182        to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 8183    }
 8184
    // Check for the quantized (mmvq) mul_mat_vec path first
 8186    vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
 8187    vk_pipeline to_q8_1 = nullptr;
 8188
 8189    if (dmmv == nullptr) {
 8190        // Fall back to f16 dequant mul mat
 8191        dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
 8192        quantize_y = false;
 8193    }
 8194
 8195    if (quantize_y) {
 8196        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
 8197    }
 8198
 8199    const bool qx_needs_dequant = x_non_contig;
 8200    const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
 8201
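    // Weights larger than maxStorageBufferRange need the 64-bit-indexing shader variant.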
 8202    if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
 8203        dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
 8204    }
 8205
    // Guard against combinations that are not implemented
 8207    GGML_ASSERT(y_non_contig || !qy_needs_dequant);  // NOLINT
 8208    GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr);  // NOLINT
 8209    GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr);  // NOLINT
 8210    GGML_ASSERT(dmmv != nullptr);
 8211
 8212    const uint64_t x_ne = ggml_nelements(src0);
 8213    const uint64_t y_ne = ggml_nelements(src1);
 8214
 8215    const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
 8216    const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
 8217    const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
 8218                                       (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
 8219
 8220    {
 8221        if (
 8222                (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
 8223                (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
 8224            GGML_ABORT("Requested preallocation size is too large");
 8225        }
 8226        if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
 8227            ctx->prealloc_size_x = x_sz;
 8228            ggml_vk_preallocate_buffers(ctx, subctx);
 8229        }
 8230        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
 8231            ctx->prealloc_size_y = y_sz;
 8232            ggml_vk_preallocate_buffers(ctx, subctx);
 8233        }
 8234
 8235        // Request descriptor sets
 8236        if (qx_needs_dequant) {
 8237            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
 8238        }
 8239        if (qy_needs_dequant) {
 8240            ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
 8241        }
 8242        if (quantize_y) {
 8243            ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
 8244        }
 8245        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
 8246    }
 8247
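    // The destination tensor is the last node of the fused group (e.g. after a fused MUL/ADD_ID epilogue).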
 8248    vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
 8249    vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
 8250    vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
 8251    vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
 8252    vk_subbuffer d_F0 = d_D;
 8253    vk_subbuffer d_X, d_Y;
 8254
 8255    if (qx_needs_dequant) {
 8256        d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
 8257    } else {
 8258        d_X = d_Qx;
 8259    }
 8260    if (qy_needs_dequant || quantize_y) {
 8261        d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
 8262    } else {
 8263        d_Y = d_Qy;
 8264    }
 8265
 8266    if (x_non_contig) {
 8267        if (ctx->prealloc_x_need_sync) {
 8268            ggml_vk_sync_buffers(ctx, subctx);
 8269        }
 8270    }
 8271
 8272    if (x_non_contig) {
 8273        GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
 8274        ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
 8275    }
 8276    if (y_non_contig) {
 8277        GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
 8278        if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
 8279            ctx->prealloc_y_last_tensor_used != src1) {
 8280            if (ctx->prealloc_y_need_sync) {
 8281                ggml_vk_sync_buffers(ctx, subctx);
 8282            }
 8283            ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
 8284            ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
 8285            ctx->prealloc_y_last_tensor_used = src1;
 8286        }
 8287    }
 8288    if (quantize_y) {
 8289        if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
 8290            ctx->prealloc_y_last_tensor_used != src1) {
 8291            if (ctx->prealloc_y_need_sync) {
 8292                ggml_vk_sync_buffers(ctx, subctx);
 8293            }
 8294            ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
 8295            ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
 8296            ctx->prealloc_y_last_tensor_used = src1;
 8297        }
 8298    }
 8299
 8300    uint32_t stride_batch_y = ne10*ne11;
 8301
 8302    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
 8303        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
 8304    }
 8305
 8306    const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
 8307
 8308    uint32_t groups_x = ne01;
 8309    uint32_t groups_z = 1;
 8310
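    // If there are more rows than the maximum x-dimension workgroup count allows,
    // spread them across the z dimension.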
 8311    if (ne01 > max_groups_x) {
 8312        groups_z = 64;
 8313        groups_x = CEIL_DIV(groups_x, groups_z);
 8314    }
 8315
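    // Detect fused epilogues: a following GGML_OP_MUL (scale) or GGML_OP_ADD_ID (bias),
    // optionally followed by a second scale.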
 8316    uint32_t fusion_flags = 0;
 8317
 8318    if (ctx->num_additional_fused_ops > 0) {
 8319        const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
 8320
 8321        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
 8322
 8323        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
 8324            fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
 8325        } else {
 8326            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
 8327            fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
 8328        }
 8329    }
 8330
 8331    vk_subbuffer d_F1 = d_D;
 8332    if (ctx->num_additional_fused_ops > 1) {
 8333        const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
 8334
 8335        d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
 8336        fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
 8337    }
 8338
    // Loop over the ids columns (the token batch dimension), one dispatch per column
 8340    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
 8341        const vk_mat_vec_id_push_constants pc = {
 8342            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
 8343            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
 8344            fusion_flags,
 8345            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
 8346        };
 8347        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
 8348            {
 8349                d_X,
 8350                d_Y,
 8351                d_D,
 8352                d_F0,
 8353                d_F1,
 8354                d_ids,
 8355            },
 8356            pc, { groups_x, (uint32_t)nei0, groups_z });
 8357    }
 8358
 8359    if (x_non_contig) {
 8360        ctx->prealloc_x_need_sync = true;
 8361    }
 8362    if (y_non_contig || quantize_y) {
 8363        ctx->prealloc_y_need_sync = true;
 8364    }
 8365}
 8366
 8367static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) {
 8368    ggml_tensor * dst = cgraph->nodes[node_idx];
 8369    ggml_tensor * src0 = dst->src[0];
 8370    ggml_tensor * src2 = dst->src[2];
 8371    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
 8372}
 8373
 8374static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
 8375    ggml_tensor * dst = cgraph->nodes[node_idx];
 8376    ggml_tensor * src0 = dst->src[0];
 8377    ggml_tensor * src1 = dst->src[1];
 8378    ggml_tensor * src2 = dst->src[2];
 8379    VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
 8380    if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
 8381        ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx);
 8382    } else {
 8383        ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
 8384    }
 8385}
 8386
 8387static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
 8388    // Needs to be kept up to date on shader changes
 8389    GGML_UNUSED(hsv);
 8390    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
 8391    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
 8392    const uint32_t Bc = scalar_flash_attention_Bc;
 8393
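    // Shared memory accounting: per-invocation reduction scratch (tmpsh, tmpshv4),
    // the Bc x Br mask tile, and the padded Q tile.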
 8394    const uint32_t tmpsh = wg_size * sizeof(float);
 8395    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
 8396
 8397    const uint32_t masksh = Bc * Br * sizeof(float);
 8398
 8399    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
 8400
 8401    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
 8402    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 8403
 8404    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
 8405
 8406    return supported;
 8407}
 8408
 8409static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
 8410    // Needs to be kept up to date on shader changes
 8411    GGML_UNUSED(hsv);
 8412    const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
 8413    const uint32_t Br = rows_cols[0];
 8414    const uint32_t Bc = rows_cols[1];
 8415
 8416    const uint32_t MatBr = 16, MatBc = 16;
 8417
 8418    const uint32_t row_split = Bc / MatBc;
 8419
 8420    const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
 8421
 8422    const uint32_t acctype = f32acc ? 4 : 2;
 8423    const uint32_t f16vec4 = 8;
 8424
 8425    const uint32_t qstride = hsk_pad / 4 + 2;
 8426    const uint32_t Qf = Br * qstride * f16vec4;
 8427
 8428    const uint32_t psh_stride = Br / 4 + 2;
 8429    const uint32_t Psh = Bc * psh_stride * f16vec4;
 8430
 8431    const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
 8432    const uint32_t sfsh = Bc * sfshstride * acctype;
 8433
 8434    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
 8435    const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
 8436    const uint32_t vsh_stride = MatBc / 4 * row_split;
 8437    const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
 8438
 8439    const uint32_t slope = Br * acctype;
 8440
 8441    const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
 8442    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 8443
 8444    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
 8445
 8446    return supported;
 8447}
 8448
 8449static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {
 8450    VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
 8451    std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
 8452    std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
 8453    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
 8454    if (sinks) {
 8455        std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
 8456    }
 8457    std::cerr << "))");
 8458
 8459    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
 8460    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
 8461    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
 8462    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
 8463    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
 8464    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
 8465    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
 8466    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
 8467
 8468    const uint32_t nem0 = mask ? mask->ne[0] : 0;
 8469    const uint32_t nem1 = mask ? mask->ne[1] : 0;
 8470    const uint32_t nem2 = mask ? mask->ne[2] : 0;
 8471    const uint32_t nem3 = mask ? mask->ne[3] : 0;
 8472
 8473    const uint32_t HSK = nek0;
 8474    const uint32_t HSV = nev0;
 8475    uint32_t N = neq1;
 8476    const uint32_t KV = nek1;
 8477
 8478    GGML_ASSERT(ne0 == HSV);
 8479    GGML_ASSERT(ne2 == N);
 8480
 8481    // input tensor rows must be contiguous
 8482    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
 8483    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
 8484    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 8485
 8486    GGML_ASSERT(neq0 == HSK);
 8487
 8488    GGML_ASSERT(neq1 == N);
 8489
 8490    GGML_ASSERT(nev1 == nek1);
 8491
 8492    // dst cannot be transposed or permuted
 8493    GGML_ASSERT(nb0 == sizeof(float));
 8494    GGML_ASSERT(nb0 <= nb1);
 8495    GGML_ASSERT(nb1 <= nb2);
 8496    GGML_ASSERT(nb2 <= nb3);
 8497
 8498    assert(dst->type == GGML_TYPE_F32);
 8499    assert(q->type == GGML_TYPE_F32);
 8500    assert(k->type == v->type);
 8501
 8502    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
 8503                      ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
 8504
 8505    if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) {
 8506        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
 8507        path = FA_SCALAR;
 8508    }
 8509
 8510    if (path == FA_COOPMAT1) {
 8511        const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
 8512                                             (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
 8513
 8514        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
 8515
 8516        if (!coopmat_shape_supported || !coopmat_shmem_supported) {
 8517            path = FA_SCALAR;
 8518        }
 8519    }
 8520
 8521    uint32_t gqa_ratio = 1;
 8522    uint32_t qk_ratio = neq2 / nek2;
 8523    uint32_t workgroups_x = (uint32_t)neq1;
 8524    uint32_t workgroups_y = (uint32_t)neq2;
 8525    uint32_t workgroups_z = (uint32_t)neq3;
 8526
 8527    const bool small_cache = nek1 < 1024;
 8528
    // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
 8530    // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
 8531    uint32_t max_gqa;
 8532    switch (path) {
 8533    case FA_SCALAR:
 8534    case FA_COOPMAT1:
 8535        // We may switch from coopmat1 to scalar, so use the scalar limit for both
 8536        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
 8537        break;
 8538    case FA_COOPMAT2:
 8539        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
 8540        break;
 8541    default:
 8542        GGML_ASSERT(0);
 8543    }
 8544
 8545    if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
 8546        qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
 8547        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
 8548        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
 8549        // and change addressing calculations to index Q's dimension 2.
 8550        gqa_ratio = qk_ratio;
 8551        N = gqa_ratio;
 8552        workgroups_y /= gqa_ratio;
 8553    }
 8554
 8555    bool small_rows = N <= get_fa_num_small_rows(path);
 8556
 8557    // coopmat1 does not actually support "small rows" (it needs 16 rows).
 8558    // So use scalar instead.
 8559    if (small_rows && path == FA_COOPMAT1) {
 8560        path = FA_SCALAR;
 8561    }
 8562
 8563    // scalar is faster than coopmat2 when N==1
 8564    if (N == 1 && path == FA_COOPMAT2) {
 8565        path = FA_SCALAR;
 8566    }
 8567
 8568    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
 8569    if (path == FA_SCALAR &&
 8570        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
 8571        small_rows = true;
 8572    }
 8573
 8574    const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
 8575    uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
 8576    uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
 8577
 8578    // For F32, the shader treats it as a block of size 4 (for vec4 loads)
 8579    if (k->type == GGML_TYPE_F32) {
 8580        k_stride /= 4;
 8581    }
 8582    if (v->type == GGML_TYPE_F32) {
 8583        v_stride /= 4;
 8584    }
 8585
 8586    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
 8587    bool aligned = (KV % alignment) == 0 &&
 8588                   // the "aligned" shader variant will forcibly align strides, for performance
 8589                   (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 8590
 8591    // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
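    // ((HSK | HSV) % 16) != 0 is true exactly when either head size is not a multiple of 16.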
 8592    if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
 8593        aligned = false;
 8594    }
 8595
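    // The scalar path always accumulates in f32; otherwise honor the requested precision.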
 8596    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
 8597
 8598    float scale         = 1.0f;
 8599    float max_bias      = 0.0f;
 8600    float logit_softcap = 0.0f;
 8601
 8602    memcpy(&scale,         (const float *) dst->op_params + 0, sizeof(float));
 8603    memcpy(&max_bias,      (const float *) dst->op_params + 1, sizeof(float));
 8604    memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
 8605
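    // Fold 1/logit_softcap into scale: the shader applies tanh to the scaled scores
    // and multiplies back by logit_softcap.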
 8606    if (logit_softcap != 0) {
 8607        scale /= logit_softcap;
 8608    }
 8609
 8610    // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
 8611    bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
 8612
 8613    uint32_t flags = (use_mask_opt       ? 1 : 0) |
 8614                     (mask != nullptr    ? 2 : 0) |
 8615                     (logit_softcap != 0 ? 4 : 0);
 8616
 8617    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
 8618
 8619    vk_pipeline pipeline = nullptr;
 8620
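    // Look up (or lazily create a placeholder for) the pipeline matching this FA configuration.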
 8621    {
 8622        std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
 8623        auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
 8624        auto it = pipelines.find(fa_pipeline_state);
 8625        if (it != pipelines.end()) {
 8626            pipeline = it->second;
 8627        } else {
 8628            pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
 8629        }
 8630    }
 8631
 8632    assert(pipeline);
 8633    // Compile early to initialize wg_denoms.
 8634    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 8635
 8636    uint32_t split_kv = KV;
 8637    uint32_t split_k = 1;
 8638
 8639    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
 8640    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
 8641
 8642    // Try to use split_k when KV is large enough to be worth the overhead.
    // Must either be a single batch or be using gqa; we can't mix the two.
 8644    if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
 8645        // Try to run two workgroups per SM.
 8646        split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
 8647        if (split_k > 1) {
 8648            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
 8649            // of "align", so recompute split_k based on that.
 8650            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
 8651            split_k = CEIL_DIV(KV, split_kv);
 8652        }
 8653    }
 8654
 8655    // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
 8656    // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
 8657    // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
 8658    // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
 8659    const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
 8660    if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
 8661        GGML_ABORT("Requested preallocation size is too large");
 8662    }
 8663    if (ctx->prealloc_size_split_k < split_k_size) {
 8664        ctx->prealloc_size_split_k = split_k_size;
 8665        ggml_vk_preallocate_buffers(ctx, subctx);
 8666    }
 8667
 8668    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
 8669    const uint32_t Br = rows_cols[0];
 8670    const uint32_t Bc = rows_cols[1];
 8671
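    // Metadata size for the mask optimization: one dword per 16 Bc-wide KV tiles,
    // for each Br-row block of the mask.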
 8672    const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
 8673    const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
 8674
 8675    vk_pipeline pipeline_fa_mask_opt = nullptr;
 8676    if (use_mask_opt) {
 8677        std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
 8678        auto &pipelines = ctx->device->pipeline_fa_mask_opt;
 8679        auto it = pipelines.find({Br, Bc});
 8680        if (it != pipelines.end()) {
 8681            pipeline_fa_mask_opt = it->second;
 8682        } else {
 8683            pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
 8684        }
 8685        assert(pipeline_fa_mask_opt);
 8686        ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
 8687
 8688        if (ctx->prealloc_size_y < mask_opt_size) {
 8689            ctx->prealloc_size_y = mask_opt_size;
 8690            ggml_vk_preallocate_buffers(ctx, subctx);
 8691        }
 8692        if (ctx->prealloc_y_need_sync) {
 8693            ggml_vk_sync_buffers(ctx, subctx);
 8694        }
 8695    }
 8696
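    // ALiBi slope bases m0/m1, derived from max_bias and the number of heads.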
 8697    const uint32_t n_head_kv   = neq2;
 8698    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
 8699    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 8700    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 8701
 8702    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
 8703    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
 8704    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
 8705    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
 8706    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
 8707    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
 8708    vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
 8709
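    // Pack the "has sinks" flag into bit 24 alongside n_head_log2.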
 8710    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
 8711
 8712    if (use_mask_opt)
 8713    {
 8714        const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
 8715            nem0,
 8716            nem1,
 8717            nem2,
 8718            (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
 8719            (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
 8720            (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
 8721            mask_opt_num_dwords,
 8722            mask_opt_num_dwords * CEIL_DIV(nem1, Br),
 8723            mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
 8724        };
 8725
 8726        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
 8727                                  { mask_buf, mask_opt_buf }, opt_pc,
 8728                                  { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
 8729        ggml_vk_sync_buffers(ctx, subctx);
 8730    }
 8731
 8732    const vk_flash_attn_push_constants pc = { N, KV,
 8733                                              (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
 8734                                              (uint32_t)neq2, (uint32_t)neq3,
 8735                                              (uint32_t)nek2, (uint32_t)nek3,
 8736                                              (uint32_t)nev2, (uint32_t)nev3,
 8737                                              nem1, nem2, nem3,
 8738                                              q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
 8739                                              k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
 8740                                              v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
 8741                                              scale, max_bias, logit_softcap,
 8742                                              mask_n_head_log2, m0, m1,
 8743                                              gqa_ratio, split_kv, split_k };
 8744
 8745    if (split_k > 1) {
 8746        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
 8747
 8748        if (ctx->prealloc_split_k_need_sync) {
 8749            ggml_vk_sync_buffers(ctx, subctx);
 8750        }
 8751        workgroups_x *= pipeline->wg_denoms[0];
 8752        vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
 8753        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
 8754                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
 8755                                    // We only use split_k when group query attention is enabled, which means
 8756                                    // there's no more than one tile of rows (i.e. workgroups_x would have been
 8757                                    // one). We reuse workgroups_x to mean the number of splits, so we need to
 8758                                    // cancel out the divide by wg_denoms[0].
 8759                                    pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
 8760
 8761        ggml_vk_sync_buffers(ctx, subctx);
 8762        const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
 8763        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
 8764                                    {split_k_buf, sinks_buf, dst_buf},
 8765                                    pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
 8766        ctx->prealloc_split_k_need_sync = true;
 8767    } else {
 8768        if (gqa_ratio > 1) {
 8769            // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
 8770            workgroups_x *= pipeline->wg_denoms[0];
 8771        }
 8772        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
 8773                                    {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
 8774                                    pc, { workgroups_x, workgroups_y, workgroups_z });
 8775    }
 8776}
 8777
 8778static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) {
 8779    auto n_tiles = [&](vk_conv_shapes s) {
 8780        return CEIL_DIV(K, vk_conv_block_sizes[s].K)
 8781            * CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ);
 8782    };
 8783
    // We can't query the number of shader cores on Intel; use 32 as a placeholder
    // so small convolutions will still choose a smaller tile.
 8786    const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
 8787
 8788    if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
 8789        return CONV_SHAPE_128x128;
 8790    } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
 8791        return CONV_SHAPE_32x256;
 8792    } else {
 8793        return CONV_SHAPE_64x32;
 8794    }
 8795}
 8796
 8797static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
 8798    switch (op) {
 8799    case GGML_OP_GET_ROWS:
 8800        GGML_ASSERT(src1->type == GGML_TYPE_I32);
 8801        if (src0->type == GGML_TYPE_I32) {
 8802            // i32 src only supports i32 result
 8803            GGML_ASSERT(dst->type == GGML_TYPE_I32);
 8804            return ctx->device->pipeline_get_rows[src0->type];
 8805        }
 8806        if (dst->type == GGML_TYPE_F16) {
 8807            return ctx->device->pipeline_get_rows[src0->type];
 8808        }
 8809        if (dst->type == GGML_TYPE_F32) {
 8810            return ctx->device->pipeline_get_rows_f32[src0->type];
 8811        }
 8812        return nullptr;
 8813    case GGML_OP_ACC:
 8814        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8815            return ctx->device->pipeline_acc_f32;
 8816        }
 8817        return nullptr;
 8818    case GGML_OP_ADD:
 8819    case GGML_OP_SUB:
 8820    case GGML_OP_MUL:
 8821    case GGML_OP_DIV:
 8822        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
 8823            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
 8824            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
 8825            return nullptr;
 8826        }
 8827        switch (op) {
 8828        case GGML_OP_ADD:
 8829        {
 8830            if (ctx->num_additional_fused_ops > 0) {
 8831                if (ctx->do_add_rms_partials) {
 8832                    return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
 8833                } else {
 8834                    return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
 8835                }
 8836            }
 8837            if (ctx->do_add_rms_partials) {
 8838                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
 8839                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 8840            } else {
 8841                auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
 8842                return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 8843            }
 8844        }
 8845        case GGML_OP_SUB:
 8846        {
 8847            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
 8848            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 8849        }
 8850        case GGML_OP_MUL:
 8851        {
 8852            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
 8853            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 8854        }
 8855        case GGML_OP_DIV:
 8856        {
 8857            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
 8858            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 8859        }
 8860        default:
 8861            break;
 8862        }
 8863        return nullptr;
 8864    case GGML_OP_ADD_ID:
 8865        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
 8866            return ctx->device->pipeline_add_id_f32;
 8867        }
 8868        return nullptr;
 8869    case GGML_OP_CONCAT:
 8870        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8871            return ctx->device->pipeline_concat_f32;
 8872        }
 8873        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
 8874            return ctx->device->pipeline_concat_f16;
 8875        }
 8876        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
 8877            return ctx->device->pipeline_concat_i32;
 8878        }
 8879        return nullptr;
 8880    case GGML_OP_UPSCALE:
 8881        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8882            uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
 8883            switch (mode) {
 8884                case GGML_SCALE_MODE_NEAREST:
 8885                    return ctx->device->pipeline_upscale_nearest_f32;
 8886                case GGML_SCALE_MODE_BILINEAR:
 8887                    return ctx->device->pipeline_upscale_bilinear_f32;
 8888                case GGML_SCALE_MODE_BICUBIC:
 8889                    return ctx->device->pipeline_upscale_bicubic_f32;
 8890                case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
 8891                    return ctx->device->pipeline_upscale_bilinear_antialias_f32;
 8892                default:
 8893                    return nullptr;
 8894            }
 8895        }
 8896        return nullptr;
 8897    case GGML_OP_SCALE:
 8898        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8899            return ctx->device->pipeline_scale_f32;
 8900        }
 8901        return nullptr;
 8902    case GGML_OP_SQR:
 8903        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8904            return ctx->device->pipeline_sqr_f32;
 8905        }
 8906        return nullptr;
 8907    case GGML_OP_SQRT:
 8908        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8909            return ctx->device->pipeline_sqrt_f32;
 8910        }
 8911        return nullptr;
 8912    case GGML_OP_SIN:
 8913        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8914            return ctx->device->pipeline_sin_f32;
 8915        }
 8916        return nullptr;
 8917    case GGML_OP_COS:
 8918        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8919            return ctx->device->pipeline_cos_f32;
 8920        }
 8921        return nullptr;
 8922    case GGML_OP_LOG:
 8923        if (src0->type == dst->type &&
 8924            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
 8925            return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
 8926        }
 8927        return nullptr;
 8928    case GGML_OP_TRI:
 8929        if (src0->type == dst->type &&
 8930            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
 8931            return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
 8932        }
 8933        return nullptr;
 8934    case GGML_OP_DIAG:
 8935        if (src0->type == dst->type &&
 8936            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
 8937            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
 8938        }
 8939        return nullptr;
 8940    case GGML_OP_CLAMP:
 8941        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8942            return ctx->device->pipeline_clamp_f32;
 8943        }
 8944        return nullptr;
 8945    case GGML_OP_PAD:
 8946        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8947            return ctx->device->pipeline_pad_f32;
 8948        }
 8949        return nullptr;
 8950    case GGML_OP_ROLL:
 8951        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8952            return ctx->device->pipeline_roll_f32;
 8953        }
 8954        return nullptr;
 8955    case GGML_OP_REPEAT:
 8956        if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
 8957            return ctx->device->pipeline_repeat_f32;
 8958        }
 8959        return nullptr;
 8960    case GGML_OP_REPEAT_BACK:
 8961        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8962            return ctx->device->pipeline_repeat_back_f32;
 8963        }
 8964        return nullptr;
 8965    case GGML_OP_CPY:
 8966    case GGML_OP_CONT:
 8967    case GGML_OP_DUP:
 8968        return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
 8969    case GGML_OP_SET_ROWS:
 8970        if (src1->type == GGML_TYPE_I64) {
 8971            return ctx->device->pipeline_set_rows_i64[dst->type];
 8972        } else {
 8973            return ctx->device->pipeline_set_rows_i32[dst->type];
 8974        }
 8975    case GGML_OP_SILU_BACK:
 8976        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8977            return ctx->device->pipeline_silu_back_f32;
 8978        }
 8979        return nullptr;
 8980    case GGML_OP_NORM:
 8981        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8982            return ctx->device->pipeline_norm_f32;
 8983        }
 8984        return nullptr;
 8985    case GGML_OP_GROUP_NORM:
 8986        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8987            return ctx->device->pipeline_group_norm_f32;
 8988        }
 8989        return nullptr;
 8990    case GGML_OP_RMS_NORM:
 8991        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 8992            if (ctx->do_add_rms_partials) {
 8993                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
 8994            } else {
 8995                return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
 8996            }
 8997        }
 8998        return nullptr;
 8999    case GGML_OP_RMS_NORM_BACK:
 9000        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9001            return ctx->device->pipeline_rms_norm_back_f32;
 9002        }
 9003        return nullptr;
 9004    case GGML_OP_L2_NORM:
 9005        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9006            return ctx->device->pipeline_l2_norm_f32;
 9007        }
 9008        return nullptr;
 9009    case GGML_OP_UNARY:
 9010        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
 9011            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
 9012            (src0->type != dst->type)) {
 9013            return nullptr;
 9014        }
 9015
 9016        switch (ggml_get_unary_op(dst)) {
 9017            case GGML_UNARY_OP_EXP:
 9018                return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
 9019            case GGML_UNARY_OP_SILU:
 9020                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
 9021            case GGML_UNARY_OP_GELU:
 9022                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
 9023            case GGML_UNARY_OP_GELU_ERF:
 9024                return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
 9025            case GGML_UNARY_OP_GELU_QUICK:
 9026                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
 9027            case GGML_UNARY_OP_RELU:
 9028                return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
 9029            case GGML_UNARY_OP_XIELU:
 9030                return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
 9031            case GGML_UNARY_OP_NEG:
 9032                return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
 9033            case GGML_UNARY_OP_TANH:
 9034                return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
 9035            case GGML_UNARY_OP_SIGMOID:
 9036                return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
 9037            case GGML_UNARY_OP_HARDSIGMOID:
 9038                return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
 9039            case GGML_UNARY_OP_HARDSWISH:
 9040                return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
 9041            case GGML_UNARY_OP_ABS:
 9042                return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
 9043            case GGML_UNARY_OP_SOFTPLUS:
 9044                return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
 9045            case GGML_UNARY_OP_STEP:
 9046                return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
 9047            case GGML_UNARY_OP_ROUND:
 9048                return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
 9049            case GGML_UNARY_OP_CEIL:
 9050                return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
 9051            case GGML_UNARY_OP_FLOOR:
 9052                return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
 9053            case GGML_UNARY_OP_TRUNC:
 9054                return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
 9055            default:
 9056                break;
 9057        }
 9058        return nullptr;
 9059    case GGML_OP_GLU:
 9060        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
 9061            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
 9062            (src0->type != dst->type)) {
 9063            return nullptr;
 9064        }
 9065
 9066        switch (ggml_get_glu_op(dst)) {
 9067            case GGML_GLU_OP_GEGLU:
 9068                return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
 9069            case GGML_GLU_OP_REGLU:
 9070                return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
 9071            case GGML_GLU_OP_SWIGLU:
 9072                return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
 9073            case GGML_GLU_OP_SWIGLU_OAI:
 9074                return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
 9075            case GGML_GLU_OP_GEGLU_ERF:
 9076                return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
 9077            case GGML_GLU_OP_GEGLU_QUICK:
 9078                return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
 9079            default:
 9080                break;
 9081        }
 9082        return nullptr;
 9083    case GGML_OP_DIAG_MASK_INF:
 9084        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9085            return ctx->device->pipeline_diag_mask_inf_f32;
 9086        }
 9087        return nullptr;
 9088    case GGML_OP_SOFT_MAX:
 9089        GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
 9090        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
 9091
 9092        if (ctx->num_additional_fused_ops) {
 9093            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
 9094            GGML_ASSERT(idx < num_topk_moe_pipelines);
            // use n_experts from the push constant when it differs from the power-of-two spec constant
 9096            bool use_push = dst->ne[0] != (1u << idx);
 9097            return ctx->device->pipeline_topk_moe[idx][use_push];
 9098        }
 9099
 9100        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
 9101            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
 9102        }
 9103        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
 9104            return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
 9105        }
 9106        return nullptr;
 9107    case GGML_OP_SOFT_MAX_BACK:
 9108        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9109            return ctx->device->pipeline_soft_max_back_f32;
 9110        }
 9111        return nullptr;
 9112    case GGML_OP_ROPE:
 9113    case GGML_OP_ROPE_BACK:
 9114        {
 9115            const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
 9116            const int mode = ((const int32_t *) rope->op_params)[2];
 9117            const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 9118            const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
 9119            const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 9120
 9121            if (is_neox) {
 9122                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9123                    return ctx->device->pipeline_rope_neox_f32;
 9124                }
 9125                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
 9126                    return ctx->device->pipeline_rope_neox_f32_f16;
 9127                }
 9128                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
 9129                    return ctx->device->pipeline_rope_neox_f16;
 9130                }
 9131            } else if (is_mrope && !is_vision) {
 9132                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9133                    return ctx->device->pipeline_rope_multi_f32;
 9134                }
 9135                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
 9136                    return ctx->device->pipeline_rope_multi_f32_f16;
 9137                }
 9138                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
 9139                    return ctx->device->pipeline_rope_multi_f16;
 9140                }
 9141            } else if (is_vision) {
 9142                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9143                    return ctx->device->pipeline_rope_vision_f32;
 9144                }
 9145                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
 9146                    return ctx->device->pipeline_rope_vision_f16;
 9147                }
 9148            } else {
 9149                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9150                    return ctx->device->pipeline_rope_norm_f32;
 9151                }
 9152                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
 9153                    return ctx->device->pipeline_rope_norm_f32_f16;
 9154                }
 9155                if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
 9156                    return ctx->device->pipeline_rope_norm_f16;
 9157                }
 9158            }
 9159            return nullptr;
 9160        }
 9161    case GGML_OP_SUM:
 9162    case GGML_OP_SUM_ROWS:
 9163    case GGML_OP_MEAN:
 9164        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9165            return ctx->device->pipeline_sum_rows_f32;
 9166        }
 9167        return nullptr;
 9168    case GGML_OP_CUMSUM:
 9169        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9170            if (src0->ne[0] <= 512) {
 9171                return ctx->device->pipeline_cumsum_small_f32;
 9172            } else {
 9173                return ctx->device->pipeline_cumsum_f32;
 9174            }
 9175        }
 9176        return nullptr;
 9177    case GGML_OP_SOLVE_TRI:
 9178        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
 9179
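            // Lazily create one specialized pipeline per (src0 ne0, src1 ne0) shape.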
 9180            vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
 9181
 9182            vk_pipeline pipeline = nullptr;

            {
                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
                auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
                if (it != ctx->device->pipeline_solve_tri_f32.end()) {
                    pipeline = it->second;
                } else {
                    ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
                }
            }

            return pipeline;
        }
        return nullptr;
    case GGML_OP_ARGMAX:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
            return ctx->device->pipeline_argmax_f32;
        }
        return nullptr;
    case GGML_OP_COUNT_EQUAL:
        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
            return ctx->device->pipeline_count_equal_i32;
        }
        return nullptr;
    case GGML_OP_IM2COL:
        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_im2col_f32;
        }
        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
            return ctx->device->pipeline_im2col_f32_f16;
        }
        return nullptr;
    case GGML_OP_IM2COL_3D:
        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_im2col_3d_f32;
        }
        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
            return ctx->device->pipeline_im2col_3d_f32_f16;
        }
        return nullptr;
    case GGML_OP_TIMESTEP_EMBEDDING:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_timestep_embedding_f32;
        }
        return nullptr;
    case GGML_OP_CONV_TRANSPOSE_1D:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_conv_transpose_1d_f32;
        }
        return nullptr;
    case GGML_OP_POOL_2D:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_pool2d_f32;
        }
        return nullptr;
    case GGML_OP_RWKV_WKV6:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_rwkv_wkv6_f32;
        }
        return nullptr;
    case GGML_OP_RWKV_WKV7:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_rwkv_wkv7_f32;
        }
        return nullptr;
    case GGML_OP_SSM_SCAN:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            const uint32_t d_state = src0->ne[0];
            if (d_state == 128) {
                return ctx->device->pipeline_ssm_scan_f32_d128;
            } else if (d_state == 256) {
                return ctx->device->pipeline_ssm_scan_f32_d256;
            }
        }
        return nullptr;
    case GGML_OP_SSM_CONV:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_ssm_conv_f32;
        }
        return nullptr;
    case GGML_OP_OPT_STEP_ADAMW:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_opt_step_adamw_f32;
        }
        return nullptr;
    case GGML_OP_OPT_STEP_SGD:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_opt_step_sgd_f32;
        }
        return nullptr;
    case GGML_OP_LEAKY_RELU:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_leaky_relu_f32;
        }
        return nullptr;
    case GGML_OP_CONV_2D:
    case GGML_OP_CONV_TRANSPOSE_2D:
        if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            uint32_t K = dst->ne[2]; // Cout
            uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW
            vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ);

            bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
            uint32_t KW = (uint32_t)src0->ne[0];
            uint32_t KH = (uint32_t)src0->ne[1];
            uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0));
            uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0;
            uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0;
            uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
            uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
            uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
            vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
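            // Like SOLVE_TRI above, conv2d pipelines are specialized per kernel
            // shape and cached; pick the cache that matches the op and source
            // type, inserting an empty placeholder on a miss.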

            std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
            if (op == GGML_OP_CONV_2D) {
                if (src0->type == GGML_TYPE_F32) {
                    pipelines = &ctx->device->pipeline_conv2d_f32[shape];
                } else if (src0->type == GGML_TYPE_F16) {
                    pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
                }
            } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
                if (src0->type == GGML_TYPE_F32) {
                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
                } else if (src0->type == GGML_TYPE_F16) {
                    pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
                }
            }

            vk_pipeline pipeline = nullptr;

            {
                std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
                auto it = pipelines->find(conv2d_pipeline_state);
                if (it != pipelines->end()) {
                    pipeline = it->second;
                } else {
                    (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
                }
            }

            return pipeline;
        }
        return nullptr;
    case GGML_OP_CONV_2D_DW:
        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            if (ggml_is_contiguous(src1)) {
                return ctx->device->pipeline_conv2d_dw_whcn_f32;
            } else if (ggml_is_contiguous_channels(src1)) {
                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
            }
        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
            if (ggml_is_contiguous(src1)) {
                return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
            } else if (ggml_is_contiguous_channels(src1)) {
                return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
            }
        }
        return nullptr;
    case GGML_OP_ADD1:
        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
            return ctx->device->pipeline_add1_f16_f16;
        }
        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
            return ctx->device->pipeline_add1_f16_f32;
        }
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_add1_f32_f32;
        }
        return nullptr;
    case GGML_OP_ARANGE:
        if (dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_arange_f32;
        }
        return nullptr;
    case GGML_OP_FILL:
        if (dst->type == GGML_TYPE_F32) {
            return ctx->device->pipeline_fill_f32;
        }
        return nullptr;
    default:
        return nullptr;
    }

    GGML_UNUSED(src2);
}
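
// The init_pushconst_tensor_offsets specializations convert each tensor's
// descriptor misalignment from bytes to elements and pack the offsets into a
// single 32-bit push constant: the source offset goes in the upper 16 bits,
// the destination offset in the lower 16 bits.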

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = (a_offset << 16) | d_offset;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = (a_offset << 16) | d_offset;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = (a_offset << 16) | d_offset;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.misalign_offsets = (a_offset << 16) | d_offset;

    GGML_UNUSED(src0);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}
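
// The binary variant packs three offsets instead: src0 in bits 16..31, src1 in
// bits 8..15 and dst in bits 0..7, so the latter two must fit in 8 bits.
// GET_ROWS requires fully aligned tensors, which the assert below enforces.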

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));

    p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;

    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}

template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
    const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
    const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);

    p.a_offset = a_offset;
    p.d_offset = d_offset;

    GGML_UNUSED(src1);
    GGML_UNUSED(src2);
    GGML_UNUSED(src3);
}
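
// Generic dispatch path for element-wise and similar ops: select the pipeline,
// bind up to four source buffers plus dst, store the misalignment offsets in
// the push constants, derive the per-op dispatch dimensions, and record the
// dispatch into the current context.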

template<typename PC>
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
    VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
    if (src1 != nullptr) {
        std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
    }
    if (src2 != nullptr) {
        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
    }
    if (src3 != nullptr) {
        std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
    }
    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
    std::cerr << "), " << ggml_op_name(op) << ")");
    GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
    GGML_ASSERT(dst->buffer != nullptr);
    const uint64_t ne00 = src0->ne[0];
    const uint64_t ne01 = src0->ne[1];
    const uint64_t ne02 = src0->ne[2];
    const uint64_t ne03 = src0->ne[3];

    const bool use_src1 = src1 != nullptr;
    const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
    const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
    const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
    const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;

    const bool use_src2 = src2 != nullptr;
    const bool use_src3 = src3 != nullptr;

    init_pushconst_fastdiv(pc);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);

    if (pipeline == nullptr) {
        std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
        if (src1 != nullptr) {
            std::cerr << " and " << ggml_type_name(src1->type);
        }
        std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
        GGML_ABORT("fatal error");
    }

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
    // Compute the misalignment offsets for the descriptors and store them in the push constants.
    init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);

    std::array<uint32_t, 3> elements;

    switch (op) {
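    // Row-wise reduction ops launch one workgroup per row. Vulkan only
    // guarantees maxComputeWorkGroupCount >= 65535 per dimension, so large
    // counts are folded into two or three dimensions (262144 = 512 * 512).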
    case GGML_OP_NORM:
    case GGML_OP_RMS_NORM_BACK:
    case GGML_OP_L2_NORM:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
    case GGML_OP_SUM_ROWS:
    case GGML_OP_CUMSUM:
    case GGML_OP_MEAN:
    case GGML_OP_ARGMAX:
        {
            const uint32_t nr = ggml_nrows(src0);
            if (nr > 262144) {
                elements = { 512, 512, CEIL_DIV(nr, 262144) };
            } else if (nr > 512) {
                elements = { 512, CEIL_DIV(nr, 512), 1 };
            } else {
                elements = { nr, 1, 1 };
            }
        } break;
    case GGML_OP_SOLVE_TRI:
        {
            uint32_t nr = (uint32_t)(ne02 * ne03);
            if (nr > 262144) {
                elements = { 512, 512, CEIL_DIV(nr, 262144) };
            } else if (nr > 512) {
                elements = { 512, CEIL_DIV(nr, 512), 1 };
            } else {
                elements = { nr, 1, 1 };
            }
        }
        break;
    case GGML_OP_RMS_NORM:
        if (ctx->do_add_rms_partials) {
            // Run one element per thread, 128 threads per workgroup
            elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
        } else {
            elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
        }
        break;

    case GGML_OP_SUM:
        // We use GGML_OP_SUM_ROWS with 1 row.
        elements = { 1, 1, 1 };
        break;
    case GGML_OP_GROUP_NORM:
        {
            const uint32_t num_groups = dst->op_params[0];
            elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
        } break;
    case GGML_OP_DIAG_MASK_INF:
        elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
        break;
    case GGML_OP_ROPE:
    case GGML_OP_ROPE_BACK:
        {
            uint32_t nrows = (uint32_t)ggml_nrows(src0);
            uint32_t z = 1;
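            // Fold excess rows into the z dimension; capping x at 32768 keeps it
            // within the spec-guaranteed minimum maxComputeWorkGroupCount of 65535.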
            if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
                z = CEIL_DIV(nrows, 32768);
                nrows = 32768;
            }
            elements = { nrows, (uint32_t)ne00, z };
        } break;
    case GGML_OP_GET_ROWS:
        elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
        elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
        elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
        break;
    case GGML_OP_ARGSORT:
        GGML_ASSERT(0);
        break;
    case GGML_OP_IM2COL:
        {
            const bool is_2D = dst->op_params[6] == 1;

            const uint32_t IC = src1->ne[is_2D ? 2 : 1];

            const uint32_t KH = is_2D ? src0->ne[1] : 1;
            const uint32_t KW =         src0->ne[0];

            const uint32_t OH = is_2D ? dst->ne[2] : 1;
            const uint32_t OW =         dst->ne[1];

            const uint32_t batch = src1->ne[is_2D ? 3 : 2];

            elements = { OW * KW * KH, OH, batch * IC };
            elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
        } break;
    case GGML_OP_IM2COL_3D:
        {
            const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];

            const uint32_t N  = ne13 / IC;

            const uint32_t KD = ne02;
            const uint32_t KH = ne01;
            const uint32_t KW = ne00;

            const uint32_t OD = dst->ne[3] / N;
            const uint32_t OH = dst->ne[2];
            const uint32_t OW = dst->ne[1];

            const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
            const uint32_t N_OD_OH = N*OD*OH;

            elements = { IC_KD_KH_KW, OW, N_OD_OH };
            elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
        } break;
    case GGML_OP_TIMESTEP_EMBEDDING:
        {
            const uint32_t dim = dst->op_params[0];
            uint32_t half_ceil = (dim + 1) / 2;
            elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
        } break;
    case GGML_OP_CONV_TRANSPOSE_1D:
        {
            elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
        } break;
    case GGML_OP_POOL_2D:
        {
            const uint32_t N = dst->ne[3];
            const uint32_t OC = dst->ne[2];
            const uint32_t OH = dst->ne[1];
            const uint32_t OW = dst->ne[0];
            elements = { N * OC * OH * OW, 1, 1};
        } break;
    case GGML_OP_CONV_2D:
    case GGML_OP_CONV_TRANSPOSE_2D:
        if constexpr (std::is_same_v<PC, vk_op_conv2d_push_constants>) {
            const uint32_t NPQ = pc.N * pc.OH * pc.OW;
            const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ);
            const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);

            elements = { pc.Cout, NPQ_blocks, 1 };
            if (elements[1] > 512) {
                elements[2] = CEIL_DIV(elements[1], 512);
                elements[1] = 512;
            }
        } else {
            GGML_ABORT("invalid push constant type for CONV_2D");
        }
        break;
    case GGML_OP_ADD:
    case GGML_OP_SUB:
    case GGML_OP_DIV:
    case GGML_OP_MUL:
    case GGML_OP_ADD1:
    case GGML_OP_ARANGE:
    case GGML_OP_FILL:
    case GGML_OP_SCALE:
    case GGML_OP_SQR:
    case GGML_OP_SQRT:
    case GGML_OP_SIN:
    case GGML_OP_COS:
    case GGML_OP_LOG:
    case GGML_OP_TRI:
    case GGML_OP_DIAG:
    case GGML_OP_CLAMP:
    case GGML_OP_PAD:
    case GGML_OP_ROLL:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
    case GGML_OP_CPY:
    case GGML_OP_CONCAT:
    case GGML_OP_UPSCALE:
    case GGML_OP_UNARY:
    case GGML_OP_GLU:
    case GGML_OP_CONV_2D_DW:
        {
            uint32_t ne = ggml_nelements(dst);
            if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
                // Convert from number of logical elements to 2- or 4-byte units.
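                // e.g. Q4_0 uses 32-element blocks of 18 bytes: 18 % 4 != 0, so
                // each block is copied as 18 / 2 = 9 16-bit words.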
                ne /= ggml_blck_size(src0->type);
                if ((ggml_type_size(src0->type) % 4) == 0) {
                    ne *= ggml_type_size(src0->type) / 4;
                } else {
                    ne *= ggml_type_size(src0->type) / 2;
                }
            }
            // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
            // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
            // So divide by block size here before splitting into 512x512 groups.
            if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
                ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
            }
            if (ne > 262144) {
                elements = { 512, 512, CEIL_DIV(ne, 262144) };
            } else if (ne > 512) {
                elements = { 512, CEIL_DIV(ne, 512), 1 };
            } else {
                elements = { ne, 1, 1 };
            }

            if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
                pipeline == ctx->device->pipeline_cpy_transpose_16) {
                // 32x32 tiles
                elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
                elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
                elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
                elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
                elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
                elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
            }
        } break;
    case GGML_OP_ADD_ID:
        {
            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
        } break;
    case GGML_OP_SET_ROWS:
        {
            uint32_t ne = ggml_nelements(src0);
            if (ggml_is_quantized(dst->type)) {
                // quants run 32 threads each doing QUANT_K elements
                ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
            } else {
                // scalar types do one element per thread, running 512 threads
                ne = CEIL_DIV(ne, 512);
            }
            if (ne > 262144) {
                elements = { 512, 512, CEIL_DIV(ne, 262144) };
            } else if (ne > 512) {
                elements = { 512, CEIL_DIV(ne, 512), 1 };
            } else {
                elements = { ne, 1, 1 };
            }
        }
        break;
    case GGML_OP_SSM_CONV:
        {
            const uint32_t nr  = src0->ne[1];
            const uint32_t n_t = dst->ne[1];
            const uint32_t n_s = dst->ne[2];
            elements = { nr, n_t, n_s };
        }
        break;
    default:
        elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
        break;
    }

    if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
        vk_subbuffer a_buf = src0_buf;
        if (ctx->do_add_rms_partials) {
            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
        }
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
    } else if (op == GGML_OP_GLU) {
        // Empty src1 is possible in glu, but the shader needs a buffer
        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
    } else if (op == GGML_OP_SOFT_MAX) {
        // Empty src1 and src2 are possible in soft_max, but the shader needs a buffer
        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
    } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
        // Empty src2 and src3 are possible in rope, but the shader needs a buffer
        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
        vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
    } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
        if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
            // the buffer-device-address path doesn't use the dst buffer
            dst_buf.size = 1;
        }
        // im2col uses only src1 and dst buffers
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
    } else if (op == GGML_OP_COUNT_EQUAL) {
        // count_equal assumes that the destination buffer is initialized with zeroes
        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
        ggml_vk_sync_buffers(ctx, subctx);
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
    } else if (op == GGML_OP_OPT_STEP_SGD) {
        // OPT_STEP_SGD works in place on src0 and does not need dst
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
    } else if (use_src3) {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
    } else if (use_src2) {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
    } else if (use_src1) {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
    } else {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
    }
}

static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    int nb1 = dst->op_params[0] / 4; // byte stride, converted to float elements
    int nb2 = dst->op_params[1] / 4; // byte stride, converted to float elements
    // int nb3 = dst->op_params[2] / 4; // unused
    int offset = dst->op_params[3] / 4; // byte offset, converted to float elements

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, offset,
    });
}
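
// Dispatch a chain of fused GGML_OP_ADD nodes (node_idx through node_idx +
// num_additional_fused_ops) as a single multi_add shader invocation: one
// descriptor per distinct input tensor, plus the final destination and,
// optionally, the RMS-partials buffer.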

static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
    const ggml_tensor *first_node = cgraph->nodes[node_idx];
    const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];

    // Make a list of all the tensors used by the op.
    // Last element of the list is the dest tensor.
    const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
    uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
    uint32_t num_tensors = num_srcs + 1;
    GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);

    tensors[0] = first_node->src[0];
    tensors[1] = first_node->src[1];
    for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
        // check whether the previous result is src[0] or src[1]
        if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
        } else {
            tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
        }
    }
    tensors[num_srcs] = dst;

    vk_op_multi_add_push_constants pc;
    pc.ne20 = (uint32_t)dst->ne[0];
    pc.ne21 = (uint32_t)dst->ne[1];
    pc.ne22 = (uint32_t)dst->ne[2];
    pc.ne23 = (uint32_t)dst->ne[3];

    for (uint32_t i = 0; i < num_tensors; ++i) {
        const ggml_tensor *t = tensors[i];
        pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
        pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
        pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
        pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
    }
    pc.rms_partials = ctx->do_add_rms_partials;

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);

    if (pipeline == nullptr) {
        std::cerr << "ggml_vulkan: Error: Missing multi_add";
        GGML_ABORT("fatal error");
    }

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
    vk_buffer buf[MAX_PARAMETER_COUNT];
    size_t offset[MAX_PARAMETER_COUNT];
    bool uma[MAX_PARAMETER_COUNT];

    for (uint32_t i = 0; i < num_tensors; ++i) {
        buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
        buf[i] = nullptr;
        offset[i] = 0;
        uma[i] = false;

        if (ctx->device->uma) {
            ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
            uma[i] = buf[i] != nullptr;
        }
        if (!uma[i]) {
            buf[i] = buf_ctx[i]->dev_buffer;
            offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
        }
        GGML_ASSERT(buf[i] != nullptr);
    }
    // If any remaining descriptors are unused, just point them at src[0]
    for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
        buf[i] = buf[0];
        offset[i] = 0;
    }
    if (ctx->do_add_rms_partials) {
        buf[num_tensors] = ctx->prealloc_add_rms_partials;
        offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
    }

    std::array<uint32_t, 3> elements;

    uint32_t ne = ggml_nelements(dst);
    if (ne > 262144) {
        elements = { 512, 512, CEIL_DIV(ne, 262144) };
    } else if (ne > 512) {
        elements = { 512, CEIL_DIV(ne, 512), 1 };
    } else {
        elements = { ne, 1, 1 };
    }

    static_assert(MAX_PARAMETER_COUNT == 12);
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {
            ggml_vk_subbuffer(ctx, buf[0], offset[0]),
            ggml_vk_subbuffer(ctx, buf[1], offset[1]),
            ggml_vk_subbuffer(ctx, buf[2], offset[2]),
            ggml_vk_subbuffer(ctx, buf[3], offset[3]),
            ggml_vk_subbuffer(ctx, buf[4], offset[4]),
            ggml_vk_subbuffer(ctx, buf[5], offset[5]),
            ggml_vk_subbuffer(ctx, buf[6], offset[6]),
            ggml_vk_subbuffer(ctx, buf[7], offset[7]),
            ggml_vk_subbuffer(ctx, buf[8], offset[8]),
            ggml_vk_subbuffer(ctx, buf[9], offset[9]),
            ggml_vk_subbuffer(ctx, buf[10], offset[10]),
            ggml_vk_subbuffer(ctx, buf[11], offset[11]),
        }, pc, elements);
}

static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, ctx->do_add_rms_partials,
    });
}

static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t src2_type_size = ggml_type_size(src2->type);

    ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
        (uint32_t)dst->ne[0],
        (uint32_t)dst->ne[1],
        (uint32_t)src0->nb[1] / src0_type_size,
        (uint32_t)src0->nb[2] / src0_type_size,
        (uint32_t)src1->nb[1] / src1_type_size,
        (uint32_t)src2->nb[1] / src2_type_size,
    });
}
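
// Shared dispatch path for RWKV wkv6/wkv7. Both versions launch one workgroup
// per (batch, head) pair; they differ only in the number of source tensors
// (six vs seven) bound alongside dst.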

static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
    GGML_ASSERT(version == 6 || version == 7);
    int num_srcs = version == 6 ? 6 : 7;

    for (int i = 0; i < num_srcs; i++) {
        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
    }

    GGML_ASSERT(dst->buffer != nullptr);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer src_buf[7] = {};
    for (int i = 0; i < num_srcs; i++) {
        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
    }

    std::array<uint32_t, 3> elements = {
        (uint32_t)(pc.B * pc.H),
        1,
        1
    };

    if (version == 6) {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
            pc, elements);
    } else if (version == 7) {
        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
            {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
            pc, elements);
    } else {
        // shouldn't happen
        GGML_ASSERT(false);
    }
}

static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const size_t seq_length = dst->src[0]->ne[2];
    const size_t n_embed = dst->ne[0];
    const size_t n_heads = dst->src[0]->ne[1];
    const size_t n_seqs = dst->src[5]->ne[1];

    ggml_vk_op_f32_wkv(
        ctx, subctx, dst,
        {
            (uint32_t)n_seqs,
            (uint32_t)seq_length,
            (uint32_t)n_embed,
            (uint32_t)n_heads,
        },
        6
    );
}

static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const size_t seq_length = dst->src[0]->ne[2];
    const size_t n_embed = dst->ne[0];
    const size_t n_heads = dst->src[0]->ne[1];
    const size_t n_seqs = dst->src[6]->ne[1];

    ggml_vk_op_f32_wkv(
        ctx, subctx, dst,
        {
            (uint32_t)n_seqs,
            (uint32_t)seq_length,
            (uint32_t)n_embed,
            (uint32_t)n_heads,
        },
        7
    );
}
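
// SSM scan (Mamba-2 layout only; the assert below rejects the original Mamba
// layout). The pipeline is specialized for d_state == 128 or 256, selected in
// ggml_vk_op_get_pipeline, and the dispatch runs one grid row per sequence.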

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
    const ggml_tensor * src3 = dst->src[3];
    const ggml_tensor * src4 = dst->src[4];
    const ggml_tensor * src5 = dst->src[5];

    GGML_ASSERT(dst->buffer != nullptr);

    const uint32_t head_dim = src0->ne[1];
    const uint32_t n_head = src1->ne[1];
    const uint32_t n_group = src4->ne[1];
    const uint32_t n_tok = src1->ne[2];
    const uint32_t n_seq = src1->ne[3];

    bool is_mamba2 = (src3->nb[1] == sizeof(float));
    GGML_ASSERT(is_mamba2);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    const int64_t s_off = ggml_nelements(src1) * sizeof(float);

    const vk_op_ssm_scan_push_constants pc = {
        (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
        (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
        (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
        (uint32_t)src3->nb[1],
        (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
        (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
        (uint32_t)s_off,
        n_head, head_dim, n_group, n_tok
    };

    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer src_buf[7] = {};
    for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
    }

    std::array<uint32_t, 3> elements;

    const uint32_t d_state = src0->ne[0];
    uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
    const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
    const uint32_t num_workgroups_y = n_seq;
    elements = { num_workgroups_x, num_workgroups_y, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
        pc, elements);
}

static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
        (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
        (uint32_t)src1->nb[1],
        (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
        (uint32_t)src1->ne[0],
        (uint32_t)src0->ne[0],
        (uint32_t)src0->ne[1],
        (uint32_t)dst->ne[1],
        (uint32_t)dst->ne[2],
    });
}

static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) {
    const ggml_tensor * x = dst->src[0];
    const ggml_tensor * g = dst->src[1];
    const ggml_tensor * gm = dst->src[2];
    const ggml_tensor * gv = dst->src[3];
    const ggml_tensor * p = dst->src[4];

    GGML_ASSERT(x->type == GGML_TYPE_F32);
    GGML_ASSERT(g->type == GGML_TYPE_F32);
    GGML_ASSERT(gm->type == GGML_TYPE_F32);
    GGML_ASSERT(gv->type == GGML_TYPE_F32);
    GGML_ASSERT(p->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->buffer != nullptr);
    GGML_ASSERT(ggml_is_contiguous(x));
    GGML_ASSERT(ggml_is_contiguous(g));
    GGML_ASSERT(ggml_is_contiguous(gm));
    GGML_ASSERT(ggml_is_contiguous(gv));
    GGML_ASSERT(ggml_is_contiguous(p));
    GGML_ASSERT(ggml_are_same_shape(x, g));
    GGML_ASSERT(ggml_are_same_shape(x, gm));
    GGML_ASSERT(ggml_are_same_shape(x, gv));
    GGML_ASSERT(ggml_nelements(p) == 7);

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
    GGML_ASSERT(pipeline != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);

    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {x_buf, g_buf, gm_buf, gv_buf, p_buf},
        pc, elements);
}

static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const size_t n = ggml_nelements(dst->src[0]);

    ggml_vk_op_f32_opt_step_adamw(
        ctx, subctx, dst,
        { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
    );
}

static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
    const size_t n = ggml_nelements(dst->src[0]);

    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
}

static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    int * op_params = (int *)dst->op_params;

    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
        (uint32_t)ggml_nelements(dst),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, op_params[0],
    });
}

static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);

    GGML_TENSOR_UNARY_OP_LOCALS

    float sf0 = (float)ne0 / ne00;
    float sf1 = (float)ne1 / ne01;
    float sf2 = (float)ne2 / ne02;
    float sf3 = (float)ne3 / ne03;
    float pixel_offset = 0.5f;

    if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
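        // With align-corners the endpoints map onto each other exactly, e.g.
        // upscaling ne00 = 2 to ne0 = 4 gives sf0 = (4 - 1) / (2 - 1) = 3 and
        // drops the half-pixel offset.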
        sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
        sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
        pixel_offset = 0.0f;
    }

    ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
        (uint32_t)ggml_nelements(dst), 0, 0,
        (uint32_t)ne00, (uint32_t)ne01,
        (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size,
        (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
        sf0, sf1, sf2, sf3, pixel_offset
    });
}

static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
    p.param1 = ggml_get_op_params_f32(dst, 0);
    p.param2 = ggml_get_op_params_f32(dst, 1);

    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p));
}

static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
    const uint32_t dst_type_size = ggml_type_size(dst->type);

    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
        (uint32_t)ggml_nelements(src0),
        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
        0,
        0.0f, 0.0f, 0,
    });
}

static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
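
    // arange op_params: [0] = start, [2] = step. The element count already
    // encodes the range, so the stop value in op_params[1] is not needed here.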
10283
10284    vk_op_push_constants pc = {
10285        (uint32_t)ggml_nelements(dst),
10286        1,
10287        ggml_get_op_params_f32(dst, 0),
10288        ggml_get_op_params_f32(dst, 2),
10289        0.0f, 0.0f,
10290    };
10291
10292    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
10293    GGML_ASSERT(pipeline != nullptr);
10294
10295    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10296    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
10297
10298    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
10299
10300    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
10301}
10302
10303static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10304    VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
10305
10306    vk_op_push_constants pc = {
10307        (uint32_t)ggml_nelements(dst),
10308        1,
10309        ggml_get_op_params_f32(dst, 0),
10310        0.0f,
10311        0.0f, 0.0f,
10312    };
10313
10314    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
10315    GGML_ASSERT(pipeline != nullptr);
10316
10317    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10318    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
10319
10320    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
10321
10322    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
10323}
10324
10325static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10326    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
10327}
10328
10329static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10330    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
10331}
10332
10333static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10334    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
10335}
10336
10337static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10338    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10339    p.param1 = ggml_get_op_params_f32(dst, 0);
10340
10341    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
10342}
10343
10344static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10345    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10346
10347    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
10348}
10349
10350static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10351    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10352    p.param1 = ggml_get_op_params_f32(dst, 0);
10353    p.param2 = ggml_get_op_params_f32(dst, 1);
10354
10355    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p));
10356}
10357
10358static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10359    vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
10360    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p));
10361}
10362
10363static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10364    const int32_t s0 = ggml_get_op_params_i32(dst, 0);
10365    const int32_t s1 = ggml_get_op_params_i32(dst, 1);
10366    const int32_t s2 = ggml_get_op_params_i32(dst, 2);
10367    const int32_t s3 = ggml_get_op_params_i32(dst, 3);
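    // Bias each 16-bit shift by 0x8000 so negative values survive the unsigned
    // packing; e.g. s0 = -3, s1 = 5 packs to (0x7FFD << 16) | 0x8005. The shader
    // is expected to undo the bias by subtracting 0x8000 per component.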
10368    const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
10369    const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
10370
10371    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10372    memcpy(&p.param1, &s01_packed, sizeof(float));
10373    memcpy(&p.param2, &s23_packed, sizeof(float));
10374
10375    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p));
10376}
10377
10378static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10379    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10380    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p));
10381}
10382
10383static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10384    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10385    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p));
10386}
10387
10388static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10389    uint32_t ne = (uint32_t)ggml_nelements(src0);
10390    if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
10391        // Convert from number of logical elements to 2- or 4-byte units.
10392        ne /= ggml_blck_size(src0->type);
10393        if ((ggml_type_size(src0->type) % 4) == 0) {
10394            ne *= ggml_type_size(src0->type) / 4;
10395        } else {
10396            ne *= ggml_type_size(src0->type) / 2;
10397        }
10398    }
10399
10400    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
10401    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p));
10402}
10403
10404static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10405    const uint32_t src0_type_size = ggml_type_size(src0->type);
10406    const uint32_t src1_type_size = ggml_type_size(src1->type);
10407    const uint32_t dst_type_size = ggml_type_size(dst->type);
10408
    // Skip empty set_rows operations. For most ops the empty check at the start
    // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
    // with empty srcs.
10412    if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
10413        return;
10414    }
10415
10416    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {
10417        (uint32_t)ggml_nelements(src0),
10418        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10419        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10420        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
10421        0,
10422        0.0f, 0.0f, 0,
10423    });
10424}
10425
10426static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10427    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10428}
10429
10430static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10431    float * op_params = (float *)dst->op_params;
10432
10433    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10434}
10435
10436static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10437    const int * int_op_params = (const int *)dst->op_params;
10438    const float * float_op_params = (const float *)dst->op_params;
10439
10440    const uint32_t num_groups = int_op_params[0];
10441    const float eps = float_op_params[1];
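    // Elements per group: one ne0 x ne1 plane times the rounded-up channels per group.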
10442    const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
10443
10444    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
10445}
10446
10447static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
10448    const uint32_t ne = (uint32_t)node->ne[0];
10449    const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
10450    const uint32_t num_partials = CEIL_DIV(ne, denom);
10451    return num_partials;
10452}
10453
10454static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
10455    const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
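    // Worked example, assuming an illustrative wg_denoms[0] of 128: ne0 = 1000
    // yields CEIL_DIV(1000, 128) = 8 partials, i.e. 32 bytes before rounding up
    // to the partials binding alignment.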
10456    const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
10457    return num_bytes;
10458}
10459
10460static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
10461    const int n_dims        = ((const int32_t *) dst->op_params)[1];
10462    const int mode          = ((const int32_t *) dst->op_params)[2];
10463    // const int n_ctx         = ((const int32_t *) dst->op_params)[3];
10464    const int n_ctx_orig    = ((const int32_t *) dst->op_params)[4];
10465    const float freq_base   = ((const float *)   dst->op_params)[5];
10466    const float freq_scale  = ((const float *)   dst->op_params)[6];
10467    const float ext_factor  = ((const float *)   dst->op_params)[7];
10468    const float attn_factor = ((const float *)   dst->op_params)[8];
10469    const float beta_fast   = ((const float *)   dst->op_params)[9];
10470    const float beta_slow   = ((const float *)   dst->op_params)[10];
10471    int sections[4] {};
10472    if (mode & GGML_ROPE_TYPE_MROPE) {
10473        memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4);
10474    }
10475
10476    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
10477
10478    float corr_dims[2];
10479    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
10480
10481    const float theta_scale = powf(freq_base, -2.0f/n_dims);
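    // With this scale, dimension pair i rotates by theta_i = pos * freq_base^(-2i/n_dims),
    // the standard RoPE frequency schedule.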
10482
10483    uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
10484    uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
10485    uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
10486
10487    uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
10488    uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
10489    uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
10490
10491    vk_op_rope_push_constants rope {
10492        (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
10493        freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
10494        { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
10495
10496        (uint32_t)src0->ne[0],
10497        (uint32_t)src0->ne[1],
10498        (uint32_t)src0->ne[2],
10499        nb01, nb02, nb03,
10500        nb11, nb12, nb13,
10501    };
10502
10503    return rope;
10504}
10505
10506static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
10507    ggml_tensor * dst;
10508    const ggml_tensor * src0;
10509    const ggml_tensor * src1;
10510
10511    if (ctx->num_additional_fused_ops > 0) {
10512        // fused rms_norm + mul
10513        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
10514        ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
10515        dst = mul;
10516        src0 = cgraph->nodes[node_idx]->src[0];
10517        src1 = other_src;
10518    } else {
10519        dst = cgraph->nodes[node_idx];
10520        src0 = src1 = dst->src[0];
10521    }
10522
10523    const uint32_t src0_type_size = ggml_type_size(src0->type);
10524    const uint32_t src1_type_size = ggml_type_size(src1->type);
10525    const uint32_t dst_type_size = ggml_type_size(dst->type);
10526
10527    uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
10528
10529    vk_op_binary_push_constants bin {
10530        (uint32_t)ggml_nelements(src0),
10531        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10532        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10533        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
10534        0,
10535        op_params[0], 0.0f, (int32_t)param3,
10536    };
10537
10538    // more than one fused op means rms_norm+mul+rope
10539    if (ctx->num_additional_fused_ops > 1) {
10540        static constexpr uint32_t max_tensors = 7;
10541        const ggml_tensor *tensors[max_tensors] {};
10542
10543        ggml_tensor *rms = cgraph->nodes[node_idx + 0];
10544        ggml_tensor *mul = cgraph->nodes[node_idx + 1];
10545        ggml_tensor *rope = cgraph->nodes[node_idx + 2];
10546
10547        ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
10548
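        // Four additional fused ops means the chain is rms_norm + mul + rope + view + set_rows.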
10549        bool do_set_rows = ctx->num_additional_fused_ops == 4;
10550
10551        tensors[0] = rms->src[0];
10552        tensors[1] = other_src;
10553        tensors[2] = mul;
10554        tensors[3] = rope->src[1]; // pos
10555        tensors[4] = rope->src[2]; // ff
10556        tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
10557        tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
10558        const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
10559
10560        vk_op_rms_norm_mul_rope_push_constants pc;
10561        pc.bin = bin;
10562        pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
10563
10564        vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
10565
10566        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10567
10568        ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
10569        vk_buffer buf[max_tensors];
10570        size_t offset[max_tensors];
10571        bool uma[max_tensors];
10572
10573        for (uint32_t i = 0; i < max_tensors; ++i) {
10574            if (!tensors[i]) {
10575                // If any remaining descriptors are unused, just point them at src[0]
10576                buf[i] = buf[0];
10577                offset[i] = 0;
10578                continue;
10579            }
10580            buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
10581            buf[i] = nullptr;
10582            offset[i] = 0;
10583            uma[i] = false;
10584
10585            if (ctx->device->uma) {
10586                ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
10587                uma[i] = buf[i] != nullptr;
10588            }
10589            if (!uma[i]) {
10590                buf[i] = buf_ctx[i]->dev_buffer;
10591                offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
10592            }
10593            GGML_ASSERT(buf[i] != nullptr);
10594        }
10595
10596        std::array<uint32_t, 3> elements;
10597        elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
10598
10599        static_assert(max_tensors == 7);
10600        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10601            {
10602                ggml_vk_subbuffer(ctx, buf[0], offset[0]),
10603                ggml_vk_subbuffer(ctx, buf[1], offset[1]),
10604                ggml_vk_subbuffer(ctx, buf[2], offset[2]),
10605                ggml_vk_subbuffer(ctx, buf[3], offset[3]),
10606                ggml_vk_subbuffer(ctx, buf[4], offset[4]),
10607                ggml_vk_subbuffer(ctx, buf[5], offset[5]),
10608                ggml_vk_subbuffer(ctx, buf[6], offset[6]),
10609            }, pc, elements);
10610    } else {
10611        ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
10612    }
10613
10614    if (ctx->do_add_rms_partials_offset_calculation) {
10615        ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
10616        ctx->do_add_rms_partials = false;
10617        ctx->do_add_rms_partials_offset_calculation = false;
10618    }
10619}
10620
10621static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10622    float * op_params = (float *)dst->op_params;
10623    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10624}
10625
10626static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10627    float * op_params = (float *)dst->op_params;
10628    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10629}
10630
10631static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10632    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10633}
10634
10635static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10636    float * op_params = (float *)dst->op_params;
10637    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
10638        {
10639            (uint32_t)ggml_nelements(src0), 0,
10640            op_params[1], op_params[2], op_params[3], op_params[4]
10641        }
10642    );
10643}
10644
10645static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10646    const float * op_params_f = (const float *)dst->op_params;
10647
10648    const bool swapped = (bool)dst->op_params[1];
10649    const bool split = src1 != nullptr;
10650    const float alpha = op_params_f[2];
10651    const float limit = op_params_f[3];
10652
10653    GGML_ASSERT(ggml_is_contiguous(src0));
10654
10655    if (!split) {
10656        GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
10657    } else {
10658        GGML_ASSERT(src0->ne[0] == src1->ne[0]);
10659        GGML_ASSERT(src0->ne[0] == dst->ne[0]);
10660        GGML_ASSERT(src0->type == src1->type);
10661    }
10662
10663    const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
10664
10665    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
10666        {
10667            (uint32_t)ggml_nelements(dst),
10668            (uint32_t)src0->ne[0],
10669            (uint32_t)dst->ne[0],
10670            mode,
10671            alpha,
10672            limit
10673        });
10674}
10675
10676static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10677    int32_t * op_params = (int32_t *)dst->op_params;
10678    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
10679}
10680
10681static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
10682    float * op_params = (float *)dst->op_params;
10683
10684    float scale = op_params[0];
10685    float max_bias = op_params[1];
10686
10687    const uint32_t ncols =   (uint32_t)src0->ne[0];
10688    const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
10689    const uint32_t nrows_y = (uint32_t)src0->ne[1];
10690
10691    const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
10692    const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
10693    const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
10694    const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
10695    const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
10696
10697    const uint32_t n_head_kv   = src0->ne[2];
10698    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
10699
10700    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
10701    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
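    // Standard ALiBi slope bases: head h below n_head_log2 is expected to use
    // slope m0^(h+1), the rest m1^(2*(h - n_head_log2) + 1).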
10702
10703    vk_op_soft_max_push_constants pc {
10704        ncols,
10705        src1 != nullptr ? nrows_y : (uint32_t)0,
10706        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
10707        ne12, ne13,
10708        nb11, nb12, nb13,
10709        scale, max_bias,
10710        m0, m1,
10711        n_head_log2,
10712        nrows_x,
10713        src2 != nullptr
10714    };
10715
10716    if (ncols <= 16384) {
10717        ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
10718    } else {
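        // Rows this wide don't fit a single-workgroup softmax, so each row is split
        // across workgroups, with two temp buffers holding per-workgroup partials.
        // The three synchronized dispatches below presumably implement a
        // reduce-then-normalize split (partial maxima, partial exponential sums,
        // final normalization); see the soft_max_large shaders for the authoritative
        // pass roles.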
10719
10720        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
10721        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
10722        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
10723        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
10724
10725        uint32_t elems_per_wg = 128 * 4;
10726        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
10727        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
10728
10729        if (ctx->prealloc_size_x < tmp_size) {
10730            ctx->prealloc_size_x = tmp_size;
10731            ggml_vk_preallocate_buffers(ctx, subctx);
10732        }
10733        if (ctx->prealloc_size_y < tmp_size) {
10734            ctx->prealloc_size_y = tmp_size;
10735            ggml_vk_preallocate_buffers(ctx, subctx);
10736        }
10737        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
10738            ggml_vk_sync_buffers(ctx, subctx);
10739        }
10740
10741        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
10742        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
10743
10744        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
10745
10746        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
10747        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
10748        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
10749
10750        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
10751        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
10752        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
10753
10754        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10755        ggml_vk_sync_buffers(ctx, subctx);
10756        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10757        ggml_vk_sync_buffers(ctx, subctx);
10758        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10759
10760        ctx->prealloc_x_need_sync = true;
10761        ctx->prealloc_y_need_sync = true;
10762    }
10763}
10764
10765static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10766    float * op_params = (float *)dst->op_params;
10767    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
10768}
10769
10770static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
10771    topk_moe_mode mode = ctx->fused_topk_moe_mode;
10772    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
10773    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
10774    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
10775    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
10776                        (mode == TOPK_MOE_LATE_SOFTMAX) ?      cgraph->nodes[node_idx + 1] :
10777                                                               cgraph->nodes[node_idx + 3];
10778
10779    GGML_ASSERT(logits->type == GGML_TYPE_F32);
10780    GGML_ASSERT(bias->type == GGML_TYPE_F32);
10781    GGML_ASSERT(weights->type == GGML_TYPE_F32);
10782    GGML_ASSERT(ids->type == GGML_TYPE_I32);
10783
10784    const int n_experts = logits->ne[0];
10785    const int n_rows    = logits->ne[1];
10786    const int n_expert_used = weights->ne[1];
10787
10788    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
10789
10790    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
10791
10792    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10793
10794    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
10795    vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
10796    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
10797    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
10798
10799    vk_op_topk_moe_push_constants pc {};
10800    pc.n_rows = n_rows;
10801    pc.n_experts_push = n_experts;
10802    pc.n_expert_used = n_expert_used;
10803    pc.clamp_min = -std::numeric_limits<float>::infinity();
10804    pc.clamp_max = std::numeric_limits<float>::infinity();
10805    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
10806        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
10807        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10808        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10809        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10810    }
10811    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
10812        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
10813        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10814        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10815        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10816    }
10817
10818#define GATING_FUNC_SOFTMAX 0
10819#define GATING_FUNC_SIGMOID 1
10820#define GATING_FUNC_SOFTMAX_WEIGHT 2
10821
10822    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
10823                     mode == TOPK_MOE_LATE_SOFTMAX ?      GATING_FUNC_SOFTMAX_WEIGHT :
10824                                                          GATING_FUNC_SOFTMAX;
10825    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10826    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10827    if (ctx->fused_topk_moe_scale) {
10828        GGML_ASSERT(weights->op == GGML_OP_SCALE);
10829        pc.output_scale = ggml_get_op_params_f32(weights, 0);
10830        pc.output_bias = ggml_get_op_params_f32(weights, 1);
10831    } else {
10832        pc.output_scale = 1.0f;
10833        pc.output_bias = 0.0f;
10834    }
10835
10836    GGML_ASSERT(n_expert_used <= n_experts);
10837
10838    const uint32_t rows_per_block = 4;
10839    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
10840
10841    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
10842}
10843
10844static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
10845    ggml_tensor * dst = cgraph->nodes[node_idx];
10846    const ggml_tensor * src0 = dst->src[0];
10847    const ggml_tensor * src1 = dst->src[1];
10848    const ggml_tensor * src2 = dst->src[2];
10849    const ggml_tensor * src3 = nullptr;
10850    const int n_dims        = ((int32_t *) dst->op_params)[1];
10851    const int mode          = ((int32_t *) dst->op_params)[2];
10852    // const int n_ctx         = ((int32_t *) dst->op_params)[3];
10853    const int n_ctx_orig    = ((int32_t *) dst->op_params)[4];
10854    const float freq_base   = ((float *)   dst->op_params)[5];
10855    const float beta_fast   = ((float *)   dst->op_params)[9];
10856    const float beta_slow   = ((float *)   dst->op_params)[10];
10857    int sections[4] {};
10858    if (mode & GGML_ROPE_TYPE_MROPE) {
10859        memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
10860    }
10861
10862    float corr_dims[2];
10863    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
10864
10865    uint32_t set_rows_stride = 0;
    // Fused rope + view + set_rows passes the set_rows destination stride via
    // set_rows_stride, overrides dst with the set_rows result, and sets src3 to the row indices.
10868    if (ctx->num_additional_fused_ops > 0) {
10869        set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
10870        src3 = cgraph->nodes[node_idx + 2]->src[1];
10871        dst = cgraph->nodes[node_idx + 2];
10872    }
10873
10874    ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
10875        ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
10876}
10877
10878static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10879    const uint32_t * op_params = (const uint32_t *)dst->op_params;
10880
10881    uint32_t ncols = src0->ne[0];
10882    uint32_t nrows = ggml_nrows(src0);
10883
10884    uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
10885    uint32_t ncolsp2 = 1 << ncols_pad_log2;
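    // e.g. ncols = 1000 -> ncols_pad_log2 = 10, ncolsp2 = 1024.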
10886
10887    vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
10888
10889    // Pick the largest workgroup size <= ncolsp2
10890    uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
10891
10892    // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
10893    bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
10894                     ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
10895
10896    vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
10897                                     : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
10898
10899    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
10900    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10901    vk_subbuffer subbuf1 = dst_buf;
10902
    // Reserve space for an ivec2 per element, with each row padded to a power-of-two length
10904    if (!use_small) {
10905        const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
10906
10907        if (ctx->prealloc_size_x < x_sz) {
10908            ctx->prealloc_size_x = x_sz;
10909            ggml_vk_preallocate_buffers(ctx, subctx);
10910        }
10911        if (ctx->prealloc_x_need_sync) {
10912            ggml_vk_sync_buffers(ctx, subctx);
10913        }
10914        subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
10915    }
10916
10917    std::array<uint32_t, 3> elements;
10918
10919    elements[0] = ncolsp2;
10920    elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
10921    elements[2] = 1;
10922
10923    // First dispatch initializes tmp_idx and does the first N passes where
10924    // there is only communication between threads in the same workgroup.
10925    {
10926        vk_op_argsort_push_constants pc2 = pc;
10927        pc2.outer_start = 0;
10928        pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
10929        pc2.inner_start = 0;
10930        pc2.inner_end = 100;
10931        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10932        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
10933    }
10934    if (!use_small) {
10935        ggml_vk_sync_buffers(ctx, subctx);
10936        // Loop over outer/inner passes, synchronizing between each pass.
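        // A full bitonic sort of 2^n elements runs n*(n+1)/2 merge stages in total;
        // the first max_workgroup_size_log2 outer blocks were already covered by
        // the initial dispatch above.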
10937        for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
10938            for (uint32_t inner = 0; inner < outer + 1; ++inner) {
10939                vk_op_argsort_push_constants pc2 = pc;
10940                pc2.outer_start = outer;
10941                pc2.outer_end = outer + 1;
10942                pc2.inner_start = inner;
10943                pc2.inner_end = inner + 1;
10944                // When the inner idx is large enough, there's only communication
10945                // within a workgroup. So the remaining inner iterations can all
10946                // run in the same dispatch.
10947                if (outer - inner < pipeline_idx) {
10948                    pc2.inner_end = 100;
10949                    inner = outer;
10950                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
10951                } else {
10952                    // Smaller workgroup empirically seems to perform better
10953                    pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
10954                }
10955                ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10956                ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
10957                ggml_vk_sync_buffers(ctx, subctx);
10958            }
10959        }
10960        ctx->prealloc_x_need_sync = true;
10961    }
10962}
10963
10964static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10965    uint32_t ncols = src0->ne[0];
10966    uint32_t nrows = ggml_nrows(src0);
10967    uint32_t k = dst->ne[0];
10968
10969    vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
10970
10971    if (ctx->prealloc_x_need_sync) {
10972        ggml_vk_sync_buffers(ctx, subctx);
10973    }
10974
10975    std::array<uint32_t, 3> elements;
10976    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
10977    elements[2] = 1;
10978
10979    uint32_t num_elements = ncols;
10980
10981    // Each iteration reduces a workgroup's worth of elements down to the K
10982    // largest elements. Repeat until we have the top K elements.
10983    // Need to do at least one iteration to write out the results.
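    // Worked example, assuming a workgroup that covers 1024 elements and k = 8:
    // ncols = 100000 -> pass 1 leaves (100000 / 1024) * 8 + min(8, 100000 % 1024)
    // = 97 * 8 + 8 = 784 candidates; pass 2 reduces 784 -> 8 and writes dst.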
10984    bool done_one_iter = false;
10985    uint32_t dbl_buf_index = 0;
10986    size_t dbl_buf_size;
10987    while (num_elements > k || !done_one_iter) {
10988
        // Prefer a workgroup as small as num_topk_pipelines - 3 for perf reasons,
        // but a larger K requires a larger workgroup.
10991        uint32_t max_pipeline = num_topk_pipelines - 1;
10992        uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
10993        max_pipeline = std::min(preferred_pipeline, max_pipeline);
10994        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
10995        // require full subgroup
10996        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
10997
10998        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
10999        pipeline_idx = std::min(pipeline_idx, max_pipeline);
11000        pipeline_idx = std::max(pipeline_idx, min_pipeline);
11001
11002        if (num_elements > (1u << pipeline_idx)) {
11003            // If we could finish on this loop iteration (i.e. a single workgroup)
11004            // then do so. It's better than the overhead of another pass.
11005            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
11006                if (num_elements <= (1u << i)) {
11007                    pipeline_idx = i;
11008                    break;
11009                }
11010            }
11011        }
11012
11013        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
11014        // If the device doesn't support a pipeline this large, use smaller
11015        while (!pipeline) {
11016            pipeline_idx--;
11017            GGML_ASSERT(pipeline_idx >= min_pipeline);
11018            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
11019        }
11020
11021        vk_op_topk_push_constants pc2 = pc;
11022        pc2.ncols_input = num_elements;
11023
11024        // Number of elements remaining after this pass
11025        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
11026
11027        pc2.ncols_output = num_dst_elements;
11028
11029        if (!done_one_iter) {
            // Reserve space for an ivec2 per element, double buffered:
            // K elements per workgroup per row.
11032            dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
11033            dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
11034            const size_t x_sz = dbl_buf_size * 2;
11035
11036            if (ctx->prealloc_size_x < x_sz) {
11037                ctx->prealloc_size_x = x_sz;
11038                ggml_vk_preallocate_buffers(ctx, subctx);
11039            }
11040        }
11041
11042        vk_subbuffer src_buf;
11043        vk_subbuffer dst_buf;
11044
11045        if (num_elements == ncols) {
11046            pc2.first_pass = 1;
11047            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
11048        } else {
11049            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
11050        }
11051        if (num_dst_elements == k) {
11052            pc2.last_pass = 1;
11053            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
11054        } else {
11055            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
11056        }
11057
11058        elements[0] = num_elements;
11059
11060        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
11061        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
11062        num_elements = num_dst_elements;
11063        dbl_buf_index ^= 1;
11064        if (num_elements > k) {
11065            ggml_vk_sync_buffers(ctx, subctx);
11066        }
11067        done_one_iter = true;
11068    }
11069    ctx->prealloc_x_need_sync = true;
11070}
11071
11072static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11073    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
11074    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
11075}
11076
11077static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11078    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11079    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p);
11080}
11081
11082static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11083    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11084    p.weight = 1.0f / (float)src0->ne[0];
11085    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
11086}
11087
11088static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11089    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11090    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
11091    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
11092    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
11093        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
11094        return;
11095    }
11096
11097    // First pass computes partial sums within a block, and stores the last partial
11098    // to the temp buffer. Second pass sums the block partials from the temp buffer
11099    // and adds that to the result of the first pass.
11100    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
11101    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
11102    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
11103
11104    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
11105    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
11106
11107    std::array<uint32_t, 3> elements;
11108
11109    elements[0] = dst->ne[0];
11110    elements[1] = (uint32_t)ggml_nrows(dst);
11111    elements[2] = 1;
11112
11113    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
11114
11115    if (ctx->prealloc_size_split_k < temp_size) {
11116        ctx->prealloc_size_split_k = temp_size;
11117        ggml_vk_preallocate_buffers(ctx, subctx);
11118    }
11119
11120    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
11121    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
11122    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
11123
11124    if (ctx->prealloc_split_k_need_sync) {
11125        ggml_vk_sync_buffers(ctx, subctx);
11126    }
11127
11128    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
11129    ggml_vk_sync_buffers(ctx, subctx);
11130    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
11131
11132    ctx->prealloc_split_k_need_sync = true;
11133}
11134
11135static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11136    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
11137}
11138
11139static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11140    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
11141}
11142
11143static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11144    const uint32_t src0_type_size = ggml_type_size(src0->type);
11145    const uint32_t src1_type_size = ggml_type_size(src1->type);
11146    const uint32_t dst_type_size = ggml_type_size(dst->type);
11147
11148    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
11149        (uint32_t)ggml_nelements(src0),
11150        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
11151        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
11152        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
11153        0,
11154        0.0f, 0.0f, 0,
11155    });
11156}
11157
11158static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11159    const int32_t s0 = dst->op_params[0];
11160    const int32_t s1 = dst->op_params[1];
11161    const int32_t p0 = dst->op_params[2];
11162    const int32_t p1 = dst->op_params[3];
11163    const int32_t d0 = dst->op_params[4];
11164    const int32_t d1 = dst->op_params[5];
11165
11166    const bool is_2D = dst->op_params[6] == 1;
11167
11168    const uint32_t IC = src1->ne[is_2D ? 2 : 1];
11169    const uint32_t IH = is_2D ? src1->ne[1] : 1;
11170    const uint32_t IW =         src1->ne[0];
11171
11172    const uint32_t KH = is_2D ? src0->ne[1] : 1;
11173    const uint32_t KW =         src0->ne[0];
11174
11175    const uint32_t OH = is_2D ? dst->ne[2] : 1;
11176    const uint32_t OW =         dst->ne[1];
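    // For reference, the usual convolution geometry ties these together as
    // OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1 (and likewise for OH).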
11177
11178    const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
11179    const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
11180
11181    const uint32_t pelements = OW * KW * KH;
11182    const uint32_t batch = src1->ne[is_2D ? 3 : 2];
11183
11184    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
11185    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
11186
11187    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
11188
11189    ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
11190        dst_addr,
11191        batch_offset, offset_delta,
11192        IC, IW, IH, OW, OH, KW, KH,
11193        pelements,
11194        IC * KH * KW,
11195        s0, s1, p0, p1, d0, d1, batch * IC
11196    });
11197}
11198
11199static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11200    GGML_TENSOR_BINARY_OP_LOCALS
11201
11202    const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
11203    const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
11204    const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
11205    const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
11206    const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
11207    const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
11208    const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
11209    const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
11210    const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
11211    const int32_t IC = ((const int32_t *)(dst->op_params))[9];
11212
11213    const int64_t N  = ne13 / IC;
11214    const int64_t ID = ne12;
11215    const int64_t IH = ne11;
11216    const int64_t IW = ne10;
11217
11218    const int64_t KD = ne02;
11219    const int64_t KH = ne01;
11220    const int64_t KW = ne00;
11221
11222    const int64_t OD = ne3 / N;
11223    const int64_t OH = ne2;
11224    const int64_t OW = ne1;
11225
11226    const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
11227    const vk_buffer d_buf = d_buf_ctx->dev_buffer;
11228
11229    const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
11230
11231    vk_op_im2col_3d_push_constants pc {};
11232
11233    pc.dst_addr = dst_addr;
11234    pc.nb10 = nb10 / ggml_type_size(src1->type);
11235    pc.nb11 = nb11 / ggml_type_size(src1->type);
11236    pc.nb12 = nb12 / ggml_type_size(src1->type);
11237    pc.nb13 = nb13 / ggml_type_size(src1->type);
11238    pc.s0 = s0;
11239    pc.s1 = s1;
11240    pc.s2 = s2;
11241    pc.p0 = p0;
11242    pc.p1 = p1;
11243    pc.p2 = p2;
11244    pc.d0 = d0;
11245    pc.d1 = d1;
11246    pc.d2 = d2;
11247    pc.IW = IW;
11248    pc.IH = IH;
11249    pc.ID = ID;
11250    pc.IC = IC;
11251    pc.KW = KW;
11252    pc.OH = OH;
11253    pc.KD_KH_KW = KD*KH*KW;
11254    pc.KH_KW = KH*KW;
11255    pc.IC_KD_KH_KW = IC*KD*KH*KW;
11256    pc.N_OD_OH = N*OD*OH;
11257    pc.OD_OH = OD*OH;
11258    pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
11259    pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
11260    pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
11261
11262    ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc));
11263}
11264
11265static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11266    const uint32_t dim = dst->op_params[0];
11267    const uint32_t max_period = dst->op_params[1];
11268    const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
11269
11270    ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
11271        nb1, dim, max_period,
11272    });
11273}
11274
11275static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11276    // src0: (K, Cout, Cin, 1) -- kernel
11277    // src1: (L, Cin, 1, 1) -- input
11278    // dst: (*, Cout, 1, 1)
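    // The output length is expected to follow the usual transposed-conv relation
    // ne0 = s0 * (L - 1) + K.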
11279
11280    GGML_ASSERT(src0->type == GGML_TYPE_F32);
11281    GGML_ASSERT(src1->type == GGML_TYPE_F32);
11282    GGML_ASSERT( dst->type == GGML_TYPE_F32);
11283
11284    GGML_TENSOR_BINARY_OP_LOCALS
11285
11286    GGML_ASSERT(nb00 == sizeof(float));
11287    GGML_ASSERT(nb10 == sizeof(float));
11288
11289    const int32_t s0 = dst->op_params[0];
11290
11291    vk_op_conv_transpose_1d_push_constants p{};
11292    p.Cout = static_cast<uint32_t>(ne01);
11293    p.Cin = static_cast<uint32_t>(ne02);
11294    p.K = static_cast<uint32_t>(ne00);
11295    p.L = static_cast<uint32_t>(ne10);
11296    p.KL = static_cast<uint32_t>(ne0);
11297    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
11298    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
11299    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
11300    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
11301    p.s0 = static_cast<uint32_t>(s0);
11302
11303    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
11304}
11305
11306static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11307    uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
11308    const int32_t k1 = dst->op_params[1];
11309    const int32_t k0 = dst->op_params[2];
11310    const int32_t s1 = dst->op_params[3];
11311    const int32_t s0 = dst->op_params[4];
11312    const int32_t p1 = dst->op_params[5];
11313    const int32_t p0 = dst->op_params[6];
11314
11315    const uint32_t IH = src0->ne[1];
11316    const uint32_t IW = src0->ne[0];
11317
11318    const uint32_t N = dst->ne[3];
11319
11320    const uint32_t OC = dst->ne[2];
11321    const uint32_t OH = dst->ne[1];
11322    const uint32_t OW = dst->ne[0];
11323
11324    const uint32_t parallel_elements = N * OC * OH * OW;
11325
11326    ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
11327        IW, IH, OW, OH, OC,
11328        parallel_elements,
11329        op,
11330        k0, k1, s0, s1, p0, p1,
11331    });
11332}
11333
11334static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
11335                            const ggml_tensor * src1, ggml_tensor * dst) {
11336    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
11337    GGML_ASSERT(src1->type == GGML_TYPE_F32);
11338    GGML_ASSERT(dst->type == GGML_TYPE_F32);
11339
11340    GGML_TENSOR_BINARY_OP_LOCALS
11341    GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
11342    GGML_ASSERT(nb10 == sizeof(float));
11343    GGML_ASSERT(nb0 == sizeof(float));
11344
11345    bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
11346
11347    vk_op_conv2d_push_constants p{};
11348    p.Cout = static_cast<uint32_t>(!transpose ? ne03 : ne02);
11349    p.Cin  = static_cast<uint32_t>(!transpose ? ne02 : ne03);
11350    p.N    = static_cast<uint32_t>(ne13);
11351    GGML_ASSERT(p.Cout == ne2);
11352    GGML_ASSERT(p.Cin == ne12);
11353
11354    p.W  = static_cast<uint32_t>(ne10);
11355    p.H  = static_cast<uint32_t>(ne11);
11356    p.OW = static_cast<uint32_t>(ne0);
11357    p.OH = static_cast<uint32_t>(ne1);
11358
11359    p.nb01 = static_cast<uint32_t>(nb01 / nb00);
11360    p.nb02 = static_cast<uint32_t>(nb02 / nb00);
11361    p.nb03 = static_cast<uint32_t>(nb03 / nb00);
11362
11363    p.nb11 = static_cast<uint32_t>(nb11 / nb10);
11364    p.nb12 = static_cast<uint32_t>(nb12 / nb10);
11365    p.nb13 = static_cast<uint32_t>(nb13 / nb10);
11366
11367    p.nb1 = static_cast<uint32_t>(nb1 / nb0);
11368    p.nb2 = static_cast<uint32_t>(nb2 / nb0);
11369    p.nb3 = static_cast<uint32_t>(nb3 / nb0);
11370
11371    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
11372}
11373
11374static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11375    vk_op_conv2d_dw_push_constants p{};
11376    p.ne = ggml_nelements(dst);
11377    p.channels = dst->ne[2];
11378    p.batches = dst->ne[3];
11379    p.dst_w = dst->ne[0];
11380    p.dst_h = dst->ne[1];
11381    p.src_w = src1->ne[0];
11382    p.src_h = src1->ne[1];
11383    p.knl_w = src0->ne[0];
11384    p.knl_h = src0->ne[1];
11385    p.stride_x = dst->op_params[0];
11386    p.stride_y = dst->op_params[1];
11387    p.pad_x = dst->op_params[2];
11388    p.pad_y = dst->op_params[3];
11389    p.dilation_x = dst->op_params[4];
11390    p.dilation_y = dst->op_params[5];
11391
11392    GGML_ASSERT(src0->ne[3] == p.channels);
11393    GGML_ASSERT(src1->ne[3] == p.batches);
11394
11395    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p));
11396}
11397
11398static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11399    const float * op_params = (const float *)dst->op_params;
11400    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
11401}
11402
11403#ifdef GGML_VULKAN_RUN_TESTS
static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
    if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
        return;
    }
    i0 = std::max(i0, 5);
    i1 = std::max(i1, 5);
    i2 = std::max(i2, 0);
    fprintf(stderr, "         ");
    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
        fprintf(stderr, "%7d ", idx1);
    }
    fprintf(stderr, "\n");
    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
        fprintf(stderr, "%7d: ", idx0);
        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
            if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
                float val;
                if (type == GGML_TYPE_F32) {
                    val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
                } else if (type == GGML_TYPE_F16) {
                    val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
                } else {
                    GGML_ABORT("fatal error");
                }
                fprintf(stderr, "% 7.2f ", val);
            } else {
                fprintf(stderr, "        ");
            }
        }
        fprintf(stderr, "\n");
    }
}

template <typename X_TYPE, typename Y_TYPE>
static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
    VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
    const size_t x_ne = m * k * batch;
    const size_t y_ne = k * n * batch;
    const size_t d_ne = m * n * batch;

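    // shader_size selects the tile-size variant: 0 = small, 1 = medium, 2 = large.
    // The aligned (a_*) pipelines are picked first and swapped for the unaligned ones
    // below if k is not a multiple of the pipeline's alignment.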
    vk_pipeline p;
    std::string shname;
    if (shader_size == 0) {
        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32->a_s;
            shname = "F32_ALIGNED_S";
        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32_f16->a_s;
            shname = "F32_F16_ALIGNED_S";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
            shname = "F16_F32_ALIGNED_S";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
            shname = "F16_ALIGNED_S";
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (shader_size == 1) {
        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32->a_m;
            shname = "F32_ALIGNED_M";
        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32_f16->a_m;
            shname = "F32_F16_ALIGNED_M";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
            shname = "F16_F32_ALIGNED_M";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
            shname = "F16_ALIGNED_M";
        } else {
            GGML_ABORT("fatal error");
        }
    } else if (shader_size == 2) {
        if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32->a_l;
            shname = "F32_ALIGNED_L";
        } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f32_f16->a_l;
            shname = "F32_F16_ALIGNED_L";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
            shname = "F16_F32_ALIGNED_L";
        } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
            p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
            shname = "F16_ALIGNED_L";
        } else {
            GGML_ABORT("fatal error");
        }
    } else {
        GGML_ASSERT(0);
    }

    const size_t kpad = ggml_vk_align_size(k, p->align);

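    // k is not a multiple of the pipeline's alignment; fall back to the unaligned variants.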
    if (k != kpad) {
        if (shader_size == 0) {
            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32->s;
                shname = "F32_S";
            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32_f16->s;
                shname = "F32_F16_S";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
                shname = "F16_F32_S";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16.f32acc->s;
                shname = "F16_S";
            }
        } else if (shader_size == 1) {
            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32->m;
                shname = "F32_M";
            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32_f16->m;
                shname = "F32_F16_M";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
                shname = "F16_F32_M";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16.f32acc->m;
                shname = "F16_M";
            }
        } else if (shader_size == 2) {
            if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32->l;
                shname = "F32_L";
            } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f32_f16->l;
                shname = "F32_F16_L";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
                shname = "F16_F32_L";
            } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
                p = ctx->device->pipeline_matmul_f16.f32acc->l;
                shname = "F16_L";
            }
        }
    }

    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
    if (split_k > 1) {
        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);

        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
            // Resize buffer
            if (ctx->prealloc_split_k != nullptr) {
                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
            }
            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
        }
    }

    ggml_pipeline_allocate_descriptor_sets(ctx);

    vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});

    X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
    Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
    float* d = (float *) malloc(sizeof(float) * d_ne);

    for (size_t i = 0; i < x_ne; i++) {
        if (std::is_same<float, X_TYPE>()) {
            x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
            // x[i] = 1.0f;
            // x[i] = i + 1;
            // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
        } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
            x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
            // x[i] = ggml_fp32_to_fp16(1.0f);
            // x[i] = ggml_fp32_to_fp16(i + 1);
            // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
        } else {
            GGML_ABORT("fatal error");
        }
    }
    for (size_t i = 0; i < y_ne; i++) {
        if (std::is_same<float, Y_TYPE>()) {
            y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
            // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
            // y[i] = i + 1;
        } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
            y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
            // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
            // y[i] = ggml_fp32_to_fp16(i + 1);
        } else {
            GGML_ABORT("fatal error");
        }
    }

    ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
    ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);

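    // Record all iterations into a single command buffer, so the timing below measures
    // one submission covering num_it matmuls.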
    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
    ggml_vk_ctx_begin(ctx->device, subctx);
    for (size_t i = 0; i < num_it; i++) {
        ggml_vk_matmul(
            ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
            m, n, k,
            k, k, m, k*m, k*n, m*n,
            split_k, batch, batch, batch, 1, 1, n
        );
    }
    ggml_vk_ctx_end(subctx);

    auto begin = std::chrono::high_resolution_clock::now();
    ggml_vk_submit(subctx, ctx->fence);
    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
    ctx->device->device.resetFences({ ctx->fence });
    ggml_vk_queue_command_pools_cleanup(ctx->device);

    auto end = std::chrono::high_resolution_clock::now();
    double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;

    // copy dst to host
    ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);

    float * d_chk = (float *) malloc(sizeof(float) * d_ne);

    ggml_init_params iparams = {
        /*.mem_size   =*/ 1024*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context * ggml_ctx = ggml_init(iparams);

    ggml_type src0_type;
    ggml_type src1_type;

    if (std::is_same<float, X_TYPE>()) {
        src0_type = GGML_TYPE_F32;
    } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
        src0_type = GGML_TYPE_F16;
    } else {
        GGML_ABORT("fatal error");
    }
    if (std::is_same<float, Y_TYPE>()) {
        src1_type = GGML_TYPE_F32;
    } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
        src1_type = GGML_TYPE_F16;
    } else {
        GGML_ABORT("fatal error");
    }

    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);

    src0_ggml->data = x;
    src1_ggml->data = y;
    tensor_ggml->data = d_chk;

    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
    ggml_build_forward_expand(cgraph, tensor_ggml);

    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);

    ggml_free(ggml_ctx);

    double avg_err = 0.0;
    int first_err_n = -1;
    int first_err_m = -1;
    int first_err_b = -1;

    for (size_t i = 0; i < m*n*batch; i++) {
        double err = std::fabs(d[i] - d_chk[i]);
        avg_err += err;

        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
            first_err_b = i / (m * n);
            first_err_n = (i % (m * n)) / m;
            first_err_m = (i % (m * n)) % m;
        }
    }

    avg_err /= m * n * batch;

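    // 2*m*n*k floating-point operations per matrix multiply, times batch and iteration count.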
    double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);

    std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;

    if (avg_err > 0.1 || std::isnan(avg_err)) {
        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
        std::cerr << "Actual result: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
        std::cerr << "Expected result: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

        if (split_k > 1) {
            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);

            std::cerr << "d_buf0: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf1: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf2: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf3: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            free(split_k_buf);
        }
    }

    free(d_chk);

    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);

    ggml_vk_destroy_buffer(d_X);
    ggml_vk_destroy_buffer(d_Y);
    ggml_vk_destroy_buffer(d_D);

    free(x);
    free(y);
    free(d);
}

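// Like ggml_vk_print_matrix_area, but reads through the tensor's byte strides (nb),
// so it also handles non-contiguous tensors.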
static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
        return;
    }
    i0 = std::max(i0, 5);
    i1 = std::max(i1, 5);
    i2 = std::max(i2, 0);
    i3 = std::max(i3, 0);
    fprintf(stderr, "         ");
    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
        fprintf(stderr, "%7d ", idx1);
    }
    fprintf(stderr, "\n");
    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
        fprintf(stderr, "%7d: ", idx0);
        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
                float val;
                if (tensor->type == GGML_TYPE_F32) {
                    val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
                } else if (tensor->type == GGML_TYPE_F16) {
                    val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
                } else {
                    GGML_ABORT("fatal error");
                }
                fprintf(stderr, "% 7.2f ", val);
            } else {
                fprintf(stderr, "        ");
            }
        }
        fprintf(stderr, "\n");
    }
}

static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
    ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
}

static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
    if (quant == GGML_TYPE_F32) {
        memcpy(to, from, sizeof(float) * ne);
        return;
    }

    const auto * tt = ggml_get_type_traits(quant);

    ggml_to_float_t dequant_fn = tt->to_float;

    dequant_fn(from, to, ne);
}

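// Round-trips ne values through the quantizer: quantize on the CPU, dequantize to fp16
// on the GPU, and compare against the CPU dequantization reference.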
static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
    VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
    const size_t x_sz = sizeof(float) * ne;
    const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
    const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
    float * x = (float *) malloc(x_sz);
    void * qx = malloc(qx_sz);
    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    float * x_ref = (float *) malloc(x_sz);
    ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);

    for (size_t i = 0; i < ne; i++) {
        x[i] = rand() / (float)RAND_MAX;
    }

    vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);

    ggml_vk_quantize_data(x, qx, ne, quant);
    ggml_vk_dequantize_data(qx, x_ref, ne, quant);

    ggml_pipeline_request_descriptor_sets(ctx, p, 1);

    ggml_pipeline_allocate_descriptor_sets(ctx);

    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);

    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
    ggml_vk_ctx_begin(ctx->device, subctx);
    const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
    ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
    ggml_vk_ctx_end(subctx);

    auto begin = std::chrono::high_resolution_clock::now();

    ggml_vk_submit(subctx, ctx->fence);
    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
    ctx->device->device.resetFences({ ctx->fence });
    ggml_vk_queue_command_pools_cleanup(ctx->device);

    auto end = std::chrono::high_resolution_clock::now();

    double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
    ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);

    int first_err = -1;

    double avg_err = 0.0;
    for (size_t i = 0; i < ne; i++) {
        double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
        avg_err += error;

        if (first_err < 0 && error > 0.05) {
            first_err = i;
        }
    }

    avg_err /= ne;

    std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;

    if (avg_err > 0.1) {
        std::cerr << "first_error = " << first_err << std::endl;
        std::cerr << "Actual result: " << std::endl << std::endl;
        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
            std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
        }
        std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
        for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
            std::cerr << x_ref[i] << ", ";
        }
        std::cerr << std::endl;
    }

    ggml_vk_destroy_buffer(x_buf);
    ggml_vk_destroy_buffer(qx_buf);

    free(x);
    free(qx);
    free(x_ref);
    free(x_chk);
}

// This does not work without ggml q8_1 quantization support
//
// typedef uint16_t ggml_half;
// typedef uint32_t ggml_half2;
//
// #define QK8_1 32
// typedef struct {
//     union {
//         struct {
//             ggml_half d; // delta
//             ggml_half s; // d * sum(qs[i])
//         } GGML_COMMON_AGGR_S;
//         ggml_half2 ds;
//     } GGML_COMMON_AGGR_U;
//     int8_t qs[QK8_1]; // quants
// } block_q8_1;
//
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
//     VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
//     GGML_ASSERT(quant == GGML_TYPE_Q8_1);
//
//     const size_t x_sz = sizeof(float) * ne;
//     const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
//     float * x = (float *) malloc(x_sz);
//     block_q8_1 * qx     = (block_q8_1 *)malloc(qx_sz);
//     block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
//     vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
//     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
//
//     for (size_t i = 0; i < ne; i++) {
//         x[i] = rand() / (float)RAND_MAX;
//     }
//
//     vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
//
//     ggml_pipeline_request_descriptor_sets(ctx, p, 1);
//
//     ggml_pipeline_allocate_descriptor_sets(ctx);
//
//     ggml_vk_buffer_write(x_buf, 0, x, x_sz);
//
//     vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
//     ggml_vk_ctx_begin(ctx->device, subctx);
//     ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
//     ggml_vk_ctx_end(subctx);
//
//     auto begin = std::chrono::high_resolution_clock::now();
//
//     ggml_vk_submit(subctx, ctx->fence);
//     VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
//     ctx->device->device.resetFences({ ctx->fence });
//     ggml_vk_queue_command_pools_cleanup(ctx->device);
//
//     auto end = std::chrono::high_resolution_clock::now();
//
//     double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
//     ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
//
//     ggml_vk_quantize_data(x, qx_res, ne, quant);
//
//     int first_err = -1;
//
//     for (size_t i = 0; i < ne / 32; i++) {
//         double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
//
//         if (first_err < 0 && error > 0.1) {
//             first_err = i;
//         }
//
//         error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
//
//         if (first_err < 0 && error > 0.1) {
//             first_err = i;
//         }
//
//         for (size_t j = 0; j < 32; j++) {
//             uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
//
//             if (first_err < 0 && error > 1) {
//                 first_err = i;
//             }
//         }
//     }
//
//     std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
//
//     if (first_err != -1) {
//         std::cerr << "first_error = " << first_err << std::endl;
//         std::cerr << "Actual result: " << std::endl << std::endl;
//         std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
//         for (size_t j = 0; j < 32; j++) {
//             std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
//         }
//         std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
//         std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
//         for (size_t j = 0; j < 32; j++) {
//             std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
//         }
//         std::cerr << std::endl;
//     }
//
//     ggml_vk_destroy_buffer(x_buf);
//     ggml_vk_destroy_buffer(qx_buf);
//
//     free(x);
//     free(qx);
//     free(qx_res);
// }

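// Tests matmul with quantized weights. With mmq set, the activations are quantized to
// q8_1 on the GPU each iteration and the q8_1 matmul pipelines are used; otherwise the
// float-path dequant matmul pipelines read the quantized weights directly.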
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
    VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
    const size_t x_ne = m * k * batch;
    const size_t y_ne = k * n * batch;
    const size_t d_ne = m * n * batch;

    vk_matmul_pipeline2 * pipelines;

    if (mmq) {
        pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
    } else {
        pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
    }

    const bool fp16acc = ctx->device->fp16;

    vk_pipeline p;
    std::string shname;
    if (shader_size == 0) {
        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
    } else if (shader_size == 1) {
        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
    } else if (shader_size == 2) {
        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
        shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
    } else {
        GGML_ASSERT(0);
    }

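    // With mmq, kpad is forced to 0 so the selection below always falls through to the
    // unaligned shader variants.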
    const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);

    if (mmq || k != kpad) {
        if (shader_size == 0) {
            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
            shname = std::string(ggml_type_name(quant)) + "_S";
        } else if (shader_size == 1) {
            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
            shname = std::string(ggml_type_name(quant)) + "_M";
        } else if (shader_size == 2) {
            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
            shname = std::string(ggml_type_name(quant)) + "_L";
        } else {
            GGML_ASSERT(0);
        }
    }

    if (p == nullptr) {
        std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
        return;
    }

    const size_t x_sz = sizeof(float) * x_ne;
    const size_t y_sz = sizeof(float) * y_ne;
    const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
    const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
    const size_t d_sz = sizeof(float) * d_ne;
    float * x = (float *) malloc(x_sz);
    float * y = (float *) malloc(y_sz);
    void * qx = malloc(qx_sz);
    vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
    float * d = (float *) malloc(d_sz);
    float * d_chk = (float *) malloc(d_sz);

    for (size_t i = 0; i < x_ne; i++) {
        x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
        // x[i] = i % k;
    }

    ggml_vk_quantize_data(x, qx, x_ne, quant);

    for (size_t i = 0; i < y_ne; i++) {
        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
        // y[i] = i % k;
    }

    ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
    if (split_k > 1) {
        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);

        if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
            // Resize buffer
            if (ctx->prealloc_split_k != nullptr) {
                ggml_vk_destroy_buffer(ctx->prealloc_split_k);
            }
            ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
        }
    }
    if (mmq) {
        vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
        ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
    }

    ggml_pipeline_allocate_descriptor_sets(ctx);

    ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
    ggml_vk_buffer_write(y_buf, 0, y, y_sz);

    vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
    ggml_vk_ctx_begin(ctx->device, subctx);
    if (mmq) {
        for (size_t i = 0; i < num_it; i++) {
            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
            ggml_vk_matmul(
                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
                m, n, k,
                k, k, m, k*m, k*n, m*n,
                split_k, batch, batch, batch, 1, 1, n
            );
        }
    } else {
        for (size_t i = 0; i < num_it; i++) {
            ggml_vk_matmul(
                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
                m, n, k,
                k, k, m, k*m, k*n, m*n,
                split_k, batch, batch, batch, 1, 1, n
            );
        }
    }
    ggml_vk_ctx_end(subctx);

    auto begin = std::chrono::high_resolution_clock::now();

    ggml_vk_submit(subctx, ctx->fence);
    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
    ctx->device->device.resetFences({ ctx->fence });
    ggml_vk_queue_command_pools_cleanup(ctx->device);

    auto end = std::chrono::high_resolution_clock::now();

    double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
    ggml_vk_buffer_read(d_buf, 0, d, d_sz);

    ggml_init_params iparams = {
        /*.mem_size   =*/ 1024*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context * ggml_ctx = ggml_init(iparams);

    ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
    ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
    ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);

    src0_ggml->data = qx;
    src1_ggml->data = y;
    tensor_ggml->data = d_chk;

    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
    ggml_build_forward_expand(cgraph, tensor_ggml);

    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);

    ggml_free(ggml_ctx);

    double avg_err = 0.0;
    int first_err_n = -1;
    int first_err_m = -1;
    int first_err_b = -1;

    for (size_t i = 0; i < m*n*batch; i++) {
        double err = std::fabs(d[i] - d_chk[i]);
        avg_err += err;

        if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
            first_err_b = i / (m * n);
            first_err_n = (i % (m * n)) / m;
            first_err_m = (i % (m * n)) % m;
        }
    }

    avg_err /= m * n * batch;

    double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);

    std::cerr << "TEST dequant matmul " << shname;
    if (mmq) {
        std::cerr << " mmq";
    }
    std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;

    if (avg_err > 0.01 || std::isnan(avg_err)) {
        std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
        std::cerr << "Actual result: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
        std::cerr << std::endl;
        std::cerr << "Expected result: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

        std::cerr << "src0: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
        std::cerr << std::endl;
        std::cerr << "src1: " << std::endl << std::endl;
        ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);

        if (split_k > 1) {
            float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
            ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);

            std::cerr << "d_buf0: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf1: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf2: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            std::cerr << "d_buf3: " << std::endl << std::endl;
            ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);

            free(split_k_buf);
        }
    }

    ggml_vk_destroy_buffer(qx_buf);
    ggml_vk_destroy_buffer(y_buf);
    ggml_vk_destroy_buffer(qy_buf);
    ggml_vk_destroy_buffer(d_buf);

    free(x);
    free(qx);
    free(y);
    free(d);
    free(d_chk);
}
#endif

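// (Re)allocates the context's persistent scratch buffers (prealloc_x/y, split_k,
// add_rms_partials) whenever the recorded required sizes exceed the current allocations.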
static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {
#if defined(GGML_VULKAN_RUN_TESTS)
    const std::vector<size_t> vals {
        512, 512, 128,
        128, 512, 512,
        4096, 512, 4096,
        11008, 512, 4096,
        4096, 512, 11008,
        32000, 512, 4096,
        8, 8, 8,
        100, 46, 576,
        623, 111, 128,
        100, 46, 558,
        512, 1, 256,
        128, 110, 622,
        511, 511, 127,
        511, 511, 7,
        511, 511, 17,
        49, 49, 128,
        128, 49, 49,
        4096, 49, 4096,
    };
    const size_t num_it = 100;

    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);

    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);

    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);

    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);

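    // Note: execution stops at the abort() below; the matmul test sweeps that follow are
    // currently disabled dead code.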
    abort();

    for (size_t i = 0; i < vals.size(); i += 3) {
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
        std::cerr << '\n';
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
        std::cerr << '\n';
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
        ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
        std::cerr << '\n' << std::endl;

        if (vals[i + 2] % 32 == 0) {
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
            std::cerr << '\n';
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
            std::cerr << '\n';
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
            std::cerr << '\n' << std::endl;
        }

        if (vals[i + 2] % 256 == 0) {
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
            std::cerr << '\n';
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
            std::cerr << '\n';
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
            ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
            std::cerr << '\n' << std::endl;
        }
    }

    GGML_ABORT("fatal error");
#endif

    if (subctx) {
        // Submit and wait for any pending work before reallocating the buffers
        ggml_vk_ctx_end(subctx);
        ggml_vk_submit(subctx, {});
        ctx->submit_pending = true;
        ggml_vk_synchronize(ctx);
        GGML_ASSERT(ctx->compute_ctx.expired());
        ggml_vk_ctx_begin(ctx->device, subctx);
        ctx->compute_ctx = subctx;
    }

    if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
        // Resize buffer
        if (ctx->prealloc_x != nullptr) {
            ggml_vk_destroy_buffer(ctx->prealloc_x);
        }
        ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
    }
    if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
        // Resize buffer
        if (ctx->prealloc_y != nullptr) {
            ggml_vk_destroy_buffer(ctx->prealloc_y);
        }
        ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
        ctx->prealloc_y_last_tensor_used = nullptr;
    }
    if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
        // Resize buffer
        if (ctx->prealloc_split_k != nullptr) {
            ggml_vk_destroy_buffer(ctx->prealloc_split_k);
        }
        ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
    }
    if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")");
        // Resize buffer
        if (ctx->prealloc_add_rms_partials != nullptr) {
            ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
        }
        ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
    }
}

static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);

// Returns true if the node has enqueued work into the queue, false otherwise.
// If submit is true, all operations queued so far are submitted to Vulkan, overlapping command-list creation with GPU execution.
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
    ggml_tensor * node = cgraph->nodes[node_idx];
    if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
        return false;
    }
    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
        return false;
    }

    VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
    ctx->semaphore_idx = 0;

    ggml_tensor * src0 = node->src[0];
    ggml_tensor * src1 = node->src[1];
    ggml_tensor * src2 = node->src[2];
    ggml_tensor * src3 = node->src[3];

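    // If this ADD feeds a single-row RMS_NORM that immediately follows it, reserve space
    // for the fused add+rms_norm partial sums (only when the device supports the fusion
    // and the preallocated partials buffer has room).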
    if (node->op == GGML_OP_ADD) {
        int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
        if (next_node_idx < cgraph->n_nodes &&
            cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
            cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
            ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
            ctx->device->add_rms_fusion) {
            uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
            ctx->do_add_rms_partials_offset_calculation = true;
            if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
                ctx->do_add_rms_partials = true;
            }
        }
    }

    vk_context compute_ctx;

    if (ctx->compute_ctx.expired()) {
        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
        ctx->compute_ctx = compute_ctx;
        ggml_vk_ctx_begin(ctx->device, compute_ctx);
    } else {
        compute_ctx = ctx->compute_ctx.lock();
    }

    {
        // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
        // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
        // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
        // outside of this logic. When a node uses one of the prealloc buffers for something like
        // dequantization or split_k, additional synchronization is needed between those passes.
        bool need_sync = false;

        // Check whether "node" requires synchronization. The node requires synchronization if it
        // overlaps in memory with another unsynchronized node and at least one of them is a write.
        // Destination nodes are checked against both the written/read lists. Source nodes are only
        // checked against the written list. Two nodes overlap in memory if they come from the same
        // buffer and the tensor or view ranges overlap.
        auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
            if (unsynced_nodes.size() == 0) {
                return false;
            }
            auto n_base = vk_tensor_offset(node) + node->view_offs;
            auto n_size = ggml_nbytes(node);
            ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
            vk_buffer a_buf = a_buf_ctx->dev_buffer;
            for (auto &other : unsynced_nodes) {
                ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
                vk_buffer o_buf = o_buf_ctx->dev_buffer;
                if (a_buf == o_buf) {
                    auto o_base = vk_tensor_offset(other) + other->view_offs;
                    auto o_size = ggml_nbytes(other);

                    if ((o_base <= n_base && n_base < o_base + o_size) ||
                        (n_base <= o_base && o_base < n_base + n_size)) {
                        return true;
                    }
                }
            }
            return false;
        };

        // For all fused ops, check if the destination node or any of the source
        // nodes require synchronization.
        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
            // If the node actually writes to memory, then check if it needs to sync
            if (ctx->fused_ops_write_mask & (1 << i)) {
                if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
                    need_sync = true;
                    break;
                }
            }
            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
                if (!cur_node->src[j]) {
                    continue;
                }
                if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
                    need_sync = true;
                    break;
                }
            }
        }

        if (need_sync) {
            if (vk_enable_sync_logger) {
                std::cerr << "sync" << std::endl;
            }
            ctx->unsynced_nodes_written.clear();
            ctx->unsynced_nodes_read.clear();
            ggml_vk_sync_buffers(ctx, compute_ctx);

            if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
                ctx->query_node_idx[ctx->query_idx] = node_idx;
                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
            }
        }
        // Add all fused nodes to the unsynchronized lists.
        for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
            const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
            // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
            if (ctx->fused_ops_write_mask & (1 << i)) {
                ctx->unsynced_nodes_written.push_back(cur_node);
            }
            for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
                if (!cur_node->src[j]) {
                    continue;
                }
                ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
            }
        }
    }
    if (vk_enable_sync_logger) {
        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
            auto *n = cgraph->nodes[node_idx + i];
            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
            if (n->op == GGML_OP_GLU) {
                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
            }
            if (n->op == GGML_OP_ROPE) {
                const int mode = ((const int32_t *) n->op_params)[2];
                std::cerr << " rope mode: " << mode;
            }
            std::cerr << std::endl;
        }
    }

    switch (node->op) {
    case GGML_OP_REPEAT:
        ggml_vk_repeat(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_REPEAT_BACK:
        ggml_vk_repeat_back(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_ACC:
        ggml_vk_acc(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_GET_ROWS:
        ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_ADD:
        if (ctx->num_additional_fused_ops) {
            ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);
        } else {
            ggml_vk_add(ctx, compute_ctx, src0, src1, node);
        }
        break;
    case GGML_OP_SUB:
        ggml_vk_sub(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_MUL:
        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_DIV:
        ggml_vk_div(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_ADD_ID:
        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);

        break;
    case GGML_OP_CONCAT:
        ggml_vk_concat(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_UPSCALE:
        ggml_vk_upscale(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_ADD1:
        ggml_vk_add1(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_ARANGE:
        ggml_vk_arange(ctx, compute_ctx, node);

        break;
    case GGML_OP_FILL:
        ggml_vk_fill(ctx, compute_ctx, node);

        break;
    case GGML_OP_SCALE:
        ggml_vk_scale(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SQR:
        ggml_vk_sqr(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SQRT:
        ggml_vk_sqrt(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SIN:
        ggml_vk_sin(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_COS:
        ggml_vk_cos(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_LOG:
        ggml_vk_log(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_TRI:
        ggml_vk_tri(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_DIAG:
        ggml_vk_diag(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_CLAMP:
        ggml_vk_clamp(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_PAD:
        ggml_vk_pad(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_ROLL:
        ggml_vk_roll(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_CPY:
    case GGML_OP_CONT:
    case GGML_OP_DUP:
        ggml_vk_cpy(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SET_ROWS:
        ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_SILU_BACK:
        ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_NORM:
        ggml_vk_norm(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_GROUP_NORM:
        ggml_vk_group_norm(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_RMS_NORM:
        ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
        break;
    case GGML_OP_RMS_NORM_BACK:
        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_L2_NORM:
        ggml_vk_l2_norm(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_UNARY:
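        // A fused top-k MoE subgraph can enter through a unary node (the
        // sigmoid variant starts with GGML_UNARY_OP_SIGMOID), so route it to
        // the fused kernel before dispatching individual unary ops.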
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
            break;
        }

        switch (ggml_get_unary_op(node)) {
        case GGML_UNARY_OP_EXP:
        case GGML_UNARY_OP_SILU:
        case GGML_UNARY_OP_GELU:
        case GGML_UNARY_OP_GELU_ERF:
        case GGML_UNARY_OP_GELU_QUICK:
        case GGML_UNARY_OP_RELU:
        case GGML_UNARY_OP_NEG:
        case GGML_UNARY_OP_TANH:
        case GGML_UNARY_OP_SIGMOID:
        case GGML_UNARY_OP_HARDSIGMOID:
        case GGML_UNARY_OP_HARDSWISH:
        case GGML_UNARY_OP_ABS:
        case GGML_UNARY_OP_SOFTPLUS:
        case GGML_UNARY_OP_STEP:
        case GGML_UNARY_OP_ROUND:
        case GGML_UNARY_OP_CEIL:
        case GGML_UNARY_OP_FLOOR:
        case GGML_UNARY_OP_TRUNC:
            ggml_vk_unary(ctx, compute_ctx, src0, node);
            break;
        case GGML_UNARY_OP_XIELU:
            ggml_vk_xielu(ctx, compute_ctx, src0, node);
            break;
        default:
            return false;
        }
        break;
    case GGML_OP_GLU:
        switch (ggml_get_glu_op(node)) {
        case GGML_GLU_OP_GEGLU:
        case GGML_GLU_OP_REGLU:
        case GGML_GLU_OP_SWIGLU:
        case GGML_GLU_OP_SWIGLU_OAI:
        case GGML_GLU_OP_GEGLU_ERF:
        case GGML_GLU_OP_GEGLU_QUICK:
            ggml_vk_glu(ctx, compute_ctx, src0, src1, node);
            break;
        default:
            return false;
        }
        break;
    case GGML_OP_DIAG_MASK_INF:
        ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SOFT_MAX:
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
        } else {
            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
        }

        break;
    case GGML_OP_SOFT_MAX_BACK:
        ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_ROPE:
        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false);

        break;
    case GGML_OP_ROPE_BACK:
        ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);

        break;
    case GGML_OP_ARGSORT:
        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
        } else {
            ggml_vk_argsort(ctx, compute_ctx, src0, node);
        }

        break;
    case GGML_OP_TOP_K:
        ggml_vk_topk(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SUM:
        ggml_vk_sum(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_SUM_ROWS:
        ggml_vk_sum_rows(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_CUMSUM:
        ggml_vk_cumsum(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_MEAN:
        ggml_vk_mean(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_ARGMAX:
        ggml_vk_argmax(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_COUNT_EQUAL:
        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_SOLVE_TRI:
        ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_IM2COL:
        ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_IM2COL_3D:
        ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_TIMESTEP_EMBEDDING:
        ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_CONV_TRANSPOSE_1D:
        ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_POOL_2D:
        ggml_vk_pool_2d(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_CONV_2D:
    case GGML_OP_CONV_TRANSPOSE_2D:
        ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_CONV_2D_DW:
        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);

        break;
    case GGML_OP_LEAKY_RELU:
        ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);

        break;
    case GGML_OP_MUL_MAT:
        ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);

        break;
    case GGML_OP_MUL_MAT_ID:
        ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);

        break;

    case GGML_OP_FLASH_ATTN_EXT:
        ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);

        break;

    case GGML_OP_RWKV_WKV6:
        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);

        break;

    case GGML_OP_RWKV_WKV7:
        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);

        break;

    case GGML_OP_SSM_SCAN:
        ggml_vk_ssm_scan(ctx, compute_ctx, node);

        break;

    case GGML_OP_SSM_CONV:
        ggml_vk_ssm_conv(ctx, compute_ctx, node);

        break;

    case GGML_OP_OPT_STEP_ADAMW:
        ggml_vk_opt_step_adamw(ctx, compute_ctx, node);

        break;

    case GGML_OP_OPT_STEP_SGD:
        ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);

        break;
    default:
        return false;
    }

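    // Remember which context recorded this node so ggml_vk_compute_forward
    // can look it up by node index later.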
    ctx->tensor_ctxs[node_idx] = compute_ctx;

#if defined(GGML_VULKAN_CHECK_RESULTS)
    // Force context reset on each node so that each tensor ends up in its own context
    // and can be run and compared to its CPU equivalent separately
    last_node = true;
#endif

    if (submit || last_node) {
        ggml_vk_ctx_end(compute_ctx);

        // TODO: it would probably be better to pass an exit_node flag to ggml_vk_compute_forward
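        // exit_tensor_idx marks the node whose completion drains the deferred
        // out_memcpys in ggml_vk_compute_forward; -1 means this submission
        // does not end the graph.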
        if (last_node) {
            compute_ctx->exit_tensor_idx = node_idx_begin;
        }
        else {
            compute_ctx->exit_tensor_idx = -1;
        }

        ctx->compute_ctx.reset();

        ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
    }
    return true;
}

static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
    GGML_UNUSED(cgraph);
    GGML_UNUSED(tensor);

    VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");

    vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();

    // Only run if the sub-context hasn't been submitted yet
    if (!subctx->seqs.empty()) {
#ifdef GGML_VULKAN_CHECK_RESULTS
        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
#endif

        // Do staging buffer copies
        for (auto& cpy : subctx->in_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }

        for (auto& mset : subctx->memsets) {
            memset(mset.dst, mset.val, mset.n);
        }

        if (almost_ready && !ctx->almost_ready_fence_pending) {
            ggml_vk_submit(subctx, ctx->almost_ready_fence);
            ctx->almost_ready_fence_pending = true;
        } else {
            ggml_vk_submit(subctx, {});
        }
        ctx->submit_pending = true;

#ifdef GGML_VULKAN_CHECK_RESULTS
        ggml_vk_synchronize(ctx);
        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
#endif
    }

    if (tensor_idx == subctx->exit_tensor_idx) {
        // Do staging buffer copies
        for (auto& cpy : subctx->out_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }
        subctx->in_memcpys.clear();
        subctx->out_memcpys.clear();
        subctx->memsets.clear();
    }
}

// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
    ctx->prealloc_y_last_pipeline_used = {};

    ctx->unsynced_nodes_written.clear();
    ctx->unsynced_nodes_read.clear();
    ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;

    ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);

    for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
        ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
    }
    ctx->gc.semaphores.clear();

    for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
        ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
    }
    ctx->gc.tl_semaphores.clear();
    ctx->semaphore_idx = 0;

    ctx->event_idx = 0;

    for (auto& event : ctx->gc.events) {
        ctx->device->device.resetEvent(event);
    }

    ctx->tensor_ctxs.clear();
    ctx->gc.contexts.clear();
    ctx->pipeline_descriptor_set_requirements = 0;
    ctx->descriptor_set_idx = 0;
}

// Clean up on backend free
static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
    // discard any unsubmitted command buffers
    ctx->compute_ctx.reset();
    // wait for any pending command buffers to finish
    ggml_vk_synchronize(ctx);

    ggml_vk_graph_cleanup(ctx);

    ggml_vk_destroy_buffer(ctx->prealloc_x);
    ggml_vk_destroy_buffer(ctx->prealloc_y);
    ggml_vk_destroy_buffer(ctx->prealloc_split_k);
    ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
    ggml_vk_destroy_buffer(ctx->sync_staging);

    ctx->prealloc_y_last_pipeline_used = nullptr;

    ctx->prealloc_size_x = 0;
    ctx->prealloc_size_y = 0;
    ctx->prealloc_size_split_k = 0;

    for (auto& event : ctx->gc.events) {
        ctx->device->device.destroyEvent(event);
    }
    ctx->gc.events.clear();

    ctx->device->device.destroyFence(ctx->fence);
    ctx->device->device.destroyFence(ctx->almost_ready_fence);

    for (auto& pool : ctx->descriptor_pools) {
        ctx->device->device.destroyDescriptorPool(pool);
    }
    ctx->descriptor_pools.clear();
    ctx->descriptor_sets.clear();

    ctx->compute_cmd_pool.destroy(ctx->device->device);
    if (vk_perf_logger_enabled) {
        ctx->perf_logger->print_timings(true);
    }
}

static int ggml_vk_get_device_count() {
    ggml_vk_instance_init();

    return vk_instance.device_indices.size();
}

static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
    ggml_vk_instance_init();

    std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();

    vk::PhysicalDeviceProperties props;
    devices[device].getProperties(&props);

    snprintf(description, description_size, "%s", props.deviceName.data());
}

// backend interface

#define UNUSED GGML_UNUSED

// device backend

static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
    return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
}

static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
    ggml_vk_destroy_buffer(ctx->dev_buffer);
    delete ctx;
}

static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
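    // Vulkan buffers are not host-addressable, so hand back a fixed dummy
    // base pointer; tensor "data" pointers then encode byte offsets that
    // vk_tensor_offset() can recover.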
    return vk_ptr_base;

    UNUSED(buffer);
}

static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
    if (tensor->view_src != nullptr) {
        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
    }
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
    VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
    vk_buffer buf = buf_ctx->dev_buffer;

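    // Replicate the byte into a 32-bit pattern (e.g. 0xAB -> 0xABABABAB),
    // since Vulkan buffer fills operate on 32-bit words.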
    uint32_t val32 = (uint32_t)value * 0x01010101;
    ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
}

static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
    vk_buffer buf = buf_ctx->dev_buffer;

    ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}

static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;

    vk_buffer buf = buf_ctx->dev_buffer;

    ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
}

static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    if (ggml_backend_buffer_is_vk(src->buffer)) {
        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;

        vk_buffer src_buf = src_buf_ctx->dev_buffer;
        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;

        ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));

        return true;
    }
    return false;

    UNUSED(buffer);
}

static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;

    ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
}

static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
    /* .free_buffer     = */ ggml_backend_vk_buffer_free_buffer,
    /* .get_base        = */ ggml_backend_vk_buffer_get_base,
    /* .init_tensor     = */ ggml_backend_vk_buffer_init_tensor,
    /* .memset_tensor   = */ ggml_backend_vk_buffer_memset_tensor,
    /* .set_tensor      = */ ggml_backend_vk_buffer_set_tensor,
    /* .get_tensor      = */ ggml_backend_vk_buffer_get_tensor,
    /* .cpy_tensor      = */ ggml_backend_vk_buffer_cpy_tensor,
    /* .clear           = */ ggml_backend_vk_buffer_clear,
    /* .reset           = */ NULL,
};

// vk buffer type
static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;

    return ctx->name.c_str();
}

static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;

    vk_buffer dev_buffer = nullptr;
    try {
        dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
    } catch (const vk::SystemError& e) {
        return nullptr;
    }

    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);

    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
}

static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
    return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
}

static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
    return ctx->device->suballocation_block_size;
}

static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    return ggml_nbytes(tensor);

    UNUSED(buft);
}

ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
    ggml_vk_instance_init();

    VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");

    vk_device dev = ggml_vk_get_device(dev_num);

    return &dev->buffer_type;
}

// host buffer type

static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
    return GGML_VK_NAME "_Host";

    UNUSED(buft);
}

static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
    return GGML_VK_NAME "_Host";

    UNUSED(buffer);
}

static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
    ggml_vk_host_free(vk_instance.devices[0], buffer->context);
}

static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");

    size += 32;  // Behave like the CPU buffer type
    void * ptr = nullptr;
    try {
        ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
    } catch (vk::SystemError& e) {
        GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
        // fallback to cpu buffer
        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    }

    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
    buffer->buft = buft;
    buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;

    return buffer;

    UNUSED(buft);
}

static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;

    UNUSED(buft);
}

static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
    return vk_instance.devices[0]->suballocation_block_size;

    UNUSED(buft);
}

// Should be changed to return device-specific host buffer type
// but that probably requires changes in llama.cpp
ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_vk_host_buffer_type_name,
            /* .alloc_buffer     = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_vk_host_buffer_type_get_alignment,
            /* .get_max_size     = */ ggml_backend_vk_host_buffer_type_get_max_size,
            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
        /* .context  = */ nullptr,
    };

    // Make sure device 0 is initialized
    ggml_vk_instance_init();
    ggml_vk_get_device(0);

    return &ggml_backend_vk_buffer_type_host;
}


// backend

static const char * ggml_backend_vk_name(ggml_backend_t backend) {
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

    return ctx->name.c_str();
}

static void ggml_backend_vk_free(ggml_backend_t backend) {
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
    VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");

    ggml_vk_cleanup(ctx);

    delete ctx;
    delete backend;
}

static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

    return &ctx->device->buffer_type;
}

static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");

    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;

    vk_context compute_ctx;

    if (ctx->compute_ctx.expired()) {
        // Initialize new transfer context
        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
        ctx->compute_ctx = compute_ctx;
        ggml_vk_ctx_begin(ctx->device, compute_ctx);
    } else {
        compute_ctx = ctx->compute_ctx.lock();
    }

    vk_buffer buf = buf_ctx->dev_buffer;

    auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;

    bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);

    if (!ret) {
        ggml_vk_ensure_sync_staging_buffer(ctx, size);
        ggml_vk_sync_buffers(nullptr, compute_ctx);

        vk::BufferCopy buffer_cpy;
        buffer_cpy.srcOffset = 0;
        buffer_cpy.dstOffset = dst_offset;
        buffer_cpy.size = size;

        compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
        deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
        ggml_vk_synchronize(ctx);
    }
}

static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
    GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");

    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;

    vk_context compute_ctx;

    if (ctx->compute_ctx.expired()) {
        // Initialize new transfer context
        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
        ctx->compute_ctx = compute_ctx;
        ggml_vk_ctx_begin(ctx->device, compute_ctx);
    } else {
        compute_ctx = ctx->compute_ctx.lock();
    }

    vk_buffer buf = buf_ctx->dev_buffer;

    auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
    bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);

    // If that failed, copy synchronously through a staging buffer
    if (!ret) {
        ggml_vk_ensure_sync_staging_buffer(ctx, size);
        ggml_vk_sync_buffers(nullptr, compute_ctx);

        vk::BufferCopy buffer_cpy;
        buffer_cpy.srcOffset = src_offset;
        buffer_cpy.dstOffset = 0;
        buffer_cpy.size = size;

        compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
        deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
        ggml_vk_synchronize(ctx);
    }
}

static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
    VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
    if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
        ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
        ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;

        vk_context compute_ctx;

        if (ctx->compute_ctx.expired()) {
            // Initialize new transfer context
            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
            ctx->compute_ctx = compute_ctx;
            ggml_vk_ctx_begin(ctx->device, compute_ctx);
        } else {
            compute_ctx = ctx->compute_ctx.lock();
        }

        vk_buffer src_buf = src_buf_ctx->dev_buffer;
        vk_buffer dst_buf = dst_buf_ctx->dev_buffer;

        ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
        return true;
    }

    return false;
}

static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_synchronize()");

    bool do_transfer = !ctx->compute_ctx.expired();

    vk_context compute_ctx;
    if (do_transfer) {
        compute_ctx = ctx->compute_ctx.lock();

        ggml_vk_ctx_end(compute_ctx);

        for (auto& cpy : compute_ctx->in_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }

        ggml_vk_submit(compute_ctx, {});
        ctx->submit_pending = true;
    }

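    // Submitting an empty batch with a fence orders the fence after all
    // previously submitted work, so waiting on it drains everything pending.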
    if (ctx->submit_pending) {
        {
            std::lock_guard<std::mutex> guard(queue_mutex);
            ctx->device->compute_queue.queue.submit({}, ctx->fence);
        }
        ggml_vk_wait_for_fence(ctx);
        ctx->submit_pending = false;
    }

    if (do_transfer) {
        for (auto& cpy : compute_ctx->out_memcpys) {
            memcpy(cpy.dst, cpy.src, cpy.n);
        }
        ctx->compute_ctx.reset();
    }
}

static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
    VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

    ggml_vk_synchronize(ctx);

    ggml_vk_graph_cleanup(ctx);
}

static bool ggml_vk_is_empty(ggml_tensor * node) {
    return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
}

static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
        return false;
    }

    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
        // additional constraints specific to this fusion
        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];

        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
        // rms_norm only supports f32
        if (mul->src[0]->type != GGML_TYPE_F32 ||
            mul->src[1]->type != GGML_TYPE_F32 ||
            mul->type != GGML_TYPE_F32) {
            return false;
        }
        // if rms_norm is the B operand, then we don't handle broadcast
        if (rms_norm == mul->src[1] &&
            !ggml_are_same_shape(mul->src[0], rms_norm)) {
            return false;
        }
        // rms_norm shader assumes contiguous rows
        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
            return false;
        }
    }
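    // Shared constraints for fusing a bias ADD onto a mat-vec result.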
    auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
        const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];

        // mat-vec only
        if (ggml_nrows(mul) != 1) {
            return false;
        }
        // shaders assume the types match
        if (mul->type != bias->type) {
            return false;
        }
        // shaders reuse the D shape for bias
        if (!ggml_are_same_shape(mul, bias) ||
            !ggml_are_same_stride(mul, bias)) {
            return false;
        }
        // unaligned bias isn't handled
        if (get_misalign_bytes(ctx, bias) != 0) {
            return false;
        }
        return true;
    };

    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
        // additional constraints specific to this fusion
        const ggml_tensor *mul = cgraph->nodes[node_idx];
        const ggml_tensor *add = cgraph->nodes[node_idx + 1];

        if (!mm_add_ok(mul, add)) {
            return false;
        }
        if (ops.size() == 3) {
            if (ops.begin()[2] != GGML_OP_ADD) {
                return false;
            }
            if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
                return false;
            }
        }
    }

    auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
        const ggml_tensor *scale = mul->src[1];

        if (mmid != mul->src[0]) {
            return false;
        }
        // mat-vec only
        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
            return false;
        }
        // shaders assume the types match
        if (mmid->type != scale->type) {
            return false;
        }
        // shaders assume the bias is contiguous
        if (!ggml_is_contiguous(scale)) {
            return false;
        }
        // unaligned bias isn't handled
        if (get_misalign_bytes(ctx, scale) != 0) {
            return false;
        }
        // shader only indexes by expert index
        if (scale->ne[0] != 1 ||
            scale->ne[1] != mul->ne[1] ||
            scale->ne[2] != 1 ||
            scale->ne[3] != 1) {
            return false;
        }
        return true;
    };

    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
        // additional constraints specific to this fusion
        const ggml_tensor *mul = cgraph->nodes[node_idx];
        const ggml_tensor *add = cgraph->nodes[node_idx + 1];
        const ggml_tensor *bias = add->src[1];

        if (mul != add->src[0]) {
            return false;
        }
        // mat-vec only
        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
            return false;
        }
        // shaders assume the types match
        if (mul->type != bias->type) {
            return false;
        }
        // shaders assume the bias is contiguous
        if (!ggml_is_contiguous(bias)) {
            return false;
        }
        // the ID tensor must be the same for mul_mat_id and add_id
        if (mul->src[2] != add->src[2]) {
            return false;
        }
        // unaligned bias isn't handled
        if (get_misalign_bytes(ctx, bias) != 0) {
            return false;
        }

        if (ops.size() == 3) {
            if (ops.begin()[2] != GGML_OP_MUL) {
                return false;
            }
            const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
            return mmid_mul_ok(add, mul);
        }
    }

    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
        // additional constraints specific to this fusion
        const ggml_tensor *mmid = cgraph->nodes[node_idx];
        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];

        if (!mmid_mul_ok(mmid, mul)) {
            return false;
        }
    }

    return true;
}

static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                      int node_idx, topk_moe_mode mode) {

    const ggml_tensor * softmax;
    const ggml_tensor * weights;
    const ggml_tensor * get_rows;
    const ggml_tensor * argsort;

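    // The node offsets below follow the fixed topk_moe_* subgraph patterns
    // that the caller already matched with ggml_can_fuse_subgraph and
    // ggml_check_edges before selecting this mode.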
    switch (mode) {
    case TOPK_MOE_EARLY_SOFTMAX_NORM:
        softmax = cgraph->nodes[node_idx + 0];
        weights = cgraph->nodes[node_idx + 9];
        get_rows = cgraph->nodes[node_idx + 4];
        argsort = cgraph->nodes[node_idx + 2];
        break;
    case TOPK_MOE_SIGMOID_NORM_BIAS:
        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
        weights = cgraph->nodes[node_idx + 10];
        get_rows = cgraph->nodes[node_idx + 5];
        argsort = cgraph->nodes[node_idx + 3];
        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
            return false;
        }
        // bias is expected to be 1D
        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
            return false;
        }
        // sigmoid fusion seems to generate infinities on moltenvk
        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
            return false;
        }
        break;
    case TOPK_MOE_EARLY_SOFTMAX:
        softmax = cgraph->nodes[node_idx + 0];
        weights = cgraph->nodes[node_idx + 4];
        get_rows = cgraph->nodes[node_idx + 4];
        argsort = cgraph->nodes[node_idx + 2];
        break;
    case TOPK_MOE_LATE_SOFTMAX:
        softmax = cgraph->nodes[node_idx + 4];
        weights = cgraph->nodes[node_idx + 5];
        get_rows = cgraph->nodes[node_idx + 2];
        argsort = cgraph->nodes[node_idx + 0];
        break;
    default:
        return false;
    }

    ggml_tensor * probs = get_rows->src[0];
    if (probs->op != GGML_OP_RESHAPE) {
        return false;
    }
    probs = probs->src[0];
    ggml_tensor * selection_probs = argsort->src[0];

    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
        return false;
    }

    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
        return false;
    }

    if (softmax->op == GGML_OP_SOFT_MAX) {
        const float * op_params = (const float *)softmax->op_params;

        float scale = op_params[0];
        float max_bias = op_params[1];

        if (scale != 1.0f || max_bias != 0.0f) {
            return false;
        }

        // don't fuse when masks or sinks are present
        if (softmax->src[1] || softmax->src[2]) {
            return false;
        }
    }

    const int n_expert = softmax->ne[0];
    if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
        return false;
    }

    if (!ctx->device->subgroup_arithmetic ||
        !ctx->device->subgroup_shuffle ||
        !ctx->device->subgroup_require_full_support ||
        ctx->device->disable_fusion) {
        return false;
    }

    return true;
}

static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                           int node_idx) {
    GGML_UNUSED(ctx);
    const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
    const ggml_tensor *view = cgraph->nodes[node_idx + 1];
    const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];

    // ne3 not tested
    if (rope->src[0]->ne[3] != 1) {
        return false;
    }

    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
        return false;
    }

    if (set_rows->src[1]->type != GGML_TYPE_I64) {
        return false;
    }

    // The view should flatten two dims of rope into one dim
    if (!ggml_is_contiguous(view) ||
        view->ne[0] != rope->ne[0] * rope->ne[1]) {
        return false;
    }

    // Only norm/neox/mrope shaders have the fusion code
    const int mode = ((const int32_t *) rope->op_params)[2];
    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
        return false;
    }

    return true;
}

// Check whether the tensors overlap in memory but are not equal.
// Fusions can potentially overwrite src tensors in ways that are not prevented
// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
// to overlap if they are exactly equal.
// XXX TODO this check is probably missing from several fusion optimizations.
static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
    ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
    vk_buffer a_buf = a_buf_ctx->dev_buffer;
    ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
    vk_buffer b_buf = b_buf_ctx->dev_buffer;
    if (a_buf == b_buf) {
        auto a_base = vk_tensor_offset(a) + a->view_offs;
        auto a_size = ggml_nbytes(a);
        auto b_base = vk_tensor_offset(b) + b->view_offs;
        auto b_size = ggml_nbytes(b);

        if (a_base == b_base && a_size == b_size) {
            return false;
        }

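        // Half-open interval overlap test: [a_base, a_base + a_size) vs
        // [b_base, b_base + b_size). E.g. two 256-byte views at offsets 0
        // and 128 of the same buffer overlap on [128, 256) without being equal.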
        if ((b_base <= a_base && a_base < b_base + b_size) ||
            (a_base <= b_base && b_base < a_base + a_size)) {
            return true;
        }
    }
    return false;
}

static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
                                               int node_idx) {
    GGML_UNUSED(ctx);
    const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
    const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
    const ggml_tensor *rope = cgraph->nodes[node_idx + 2];

    const int mode = ((const int32_t *) rope->op_params)[2];

    // noncontig tensors aren't tested, and don't seem common in practice
    if (!ggml_is_contiguous(rms) ||
        !ggml_is_contiguous(mul) ||
        !ggml_is_contiguous(rope)) {
        return false;
    }

    // only norm/neox are handled in the shader
    if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
        return false;
    }

    // shared memory size for passing data from mul->rope
    if (mul->ne[0] > 1024) {
        return false;
    }

    // must not overwrite srcs in a way that's not elementwise
    ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
    if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
        ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
        return false;
    }

    // conditions for pipeline creation
    if (!(ctx->device->float_controls_rte_fp16 &&
        sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
        return false;
    }

    return true;
}

static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {

    const ggml_tensor *first_node = cgraph->nodes[node_idx];
    if (first_node->op != GGML_OP_ADD) {
        return 0;
    }

    if (!ctx->device->multi_add) {
        return 0;
    }

    int32_t num_adds = 1;
    while (node_idx + num_adds < cgraph->n_nodes &&
           cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
           num_adds < MAX_FUSED_ADDS) {
        num_adds++;
    }

    // The shader currently requires same shapes (but different strides are allowed),
    // everything f32, and no misalignment
    for (int32_t i = 0; i < num_adds; ++i) {
        const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
        if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
            !ggml_are_same_shape(first_node, next_node->src[1]) ||
            next_node->type != GGML_TYPE_F32 ||
            next_node->src[0]->type != GGML_TYPE_F32 ||
            next_node->src[1]->type != GGML_TYPE_F32 ||
            get_misalign_bytes(ctx, next_node) ||
            get_misalign_bytes(ctx, next_node->src[0]) ||
            get_misalign_bytes(ctx, next_node->src[1])) {
            num_adds = i;
        }
    }

    // Verify we can fuse these
    ggml_op adds[MAX_FUSED_ADDS];
    for (int32_t i = 0; i < num_adds; ++i) {
        adds[i] = GGML_OP_ADD;
    }

    // decrease num_adds if they can't all be fused
    while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
        num_adds--;
    }

    // a single add is not "fused", so just return zero
    if (num_adds == 1) {
        return 0;
    }
    return num_adds;
}

static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

    if (vk_instance.debug_utils_support) {
        vk::DebugUtilsLabelEXT dul = {};
        dul.pLabelName = "ggml_backend_vk_graph_compute";
        dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
        vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
    }

    ctx->prealloc_size_add_rms_partials_offset = 0;
    ctx->do_add_rms_partials = false;
    ctx->do_add_rms_partials_offset_calculation = false;

    int last_node = cgraph->n_nodes - 1;

    // Walk back past empty/non-compute nodes: if the last op in the cgraph doesn't run on the GPU, the command buffer wouldn't get closed properly
    while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
        last_node -= 1;
    }

    // Reserve tensor context space for all nodes
    ctx->tensor_ctxs.resize(cgraph->n_nodes);

    bool first_node_in_batch = true; // true if next node will be first node in a batch
    int submit_node_idx = 0; // index to first node in a batch

    vk_context compute_ctx;
    if (vk_perf_logger_enabled) {
        // allocate/resize the query pool
        if (ctx->num_queries < cgraph->n_nodes + 1) {
            if (ctx->query_pool) {
                ctx->device->device.destroyQueryPool(ctx->query_pool);
            }
            vk::QueryPoolCreateInfo query_create_info;
            query_create_info.queryType = vk::QueryType::eTimestamp;
            query_create_info.queryCount = cgraph->n_nodes + 100;
            ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
            ctx->num_queries = query_create_info.queryCount;
            ctx->query_fusion_names.resize(ctx->num_queries);
            ctx->query_fusion_node_count.resize(ctx->num_queries);
            ctx->query_nodes.resize(ctx->num_queries);
            ctx->query_node_idx.resize(ctx->num_queries);
        }

        ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
        std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
        std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
        std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
        std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);

        GGML_ASSERT(ctx->compute_ctx.expired());
        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
        ctx->compute_ctx = compute_ctx;
        ggml_vk_ctx_begin(ctx->device, compute_ctx);
        ctx->query_idx = 0;
        compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
    }

    ctx->prealloc_y_last_pipeline_used = nullptr;
    ctx->prealloc_y_last_tensor_used = nullptr;

    if (ctx->prealloc_size_add_rms_partials) {
        ggml_vk_preallocate_buffers(ctx, nullptr);
        if (ctx->compute_ctx.expired()) {
            compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
            ctx->compute_ctx = compute_ctx;
            ggml_vk_ctx_begin(ctx->device, compute_ctx);
        } else {
            compute_ctx = ctx->compute_ctx.lock();
        }
        // initialize partial sums to zero.
        ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
        ggml_vk_sync_buffers(ctx, compute_ctx);
    }

    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
    // (and scaled down based on model size, so smaller models submit earlier).
    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
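    // E.g. with last_total_mul_mat_bytes == 4 GB the threshold is
    // min(100 MB, 4 GB / 40) == 100 MB of weights per submit; a 400 MB model
    // would submit roughly every 10 MB.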
13872    int nodes_per_submit = 100;
13873    int submitted_nodes = 0;
13874    int submit_count = 0;
13875    uint64_t mul_mat_bytes = 0;
13876    uint64_t total_mul_mat_bytes = 0;
13877    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
13878    for (int i = 0; i < cgraph->n_nodes; i++) {
13879        if (first_node_in_batch) {
13880            submit_node_idx = i;
13881        }
13882
13883        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
13884            auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
13885            mul_mat_bytes += bytes;
13886            total_mul_mat_bytes += bytes;
13887        }
13888
13889        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
13890        ctx->fused_topk_moe_scale = false;
13891        const char *fusion_string {};
13892        if (!ctx->device->disable_fusion) {
13893            uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
13894            if (num_adds) {
13895                ctx->num_additional_fused_ops = num_adds - 1;
13896                fusion_string = "MULTI_ADD";
13897            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
13898                ctx->num_additional_fused_ops = 2;
13899                fusion_string = "MUL_MAT_ADD_ADD";
13900            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
13901                ctx->num_additional_fused_ops = 1;
13902                fusion_string = "MUL_MAT_ADD";
13903            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
13904                ctx->num_additional_fused_ops = 2;
13905                fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
13906            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
13907                ctx->num_additional_fused_ops = 1;
13908                fusion_string = "MUL_MAT_ID_ADD_ID";
13909            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
13910                ctx->num_additional_fused_ops = 1;
13911                fusion_string = "MUL_MAT_ID_MUL";
13912            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
13913                       ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
13914                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
13915                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
13916                ctx->num_additional_fused_ops = 4;
13917                fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
13918            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
13919                       ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
13920                ctx->num_additional_fused_ops = 2;
13921                fusion_string = "RMS_NORM_MUL_ROPE";
13922            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
13923                ctx->num_additional_fused_ops = 1;
13924                fusion_string = "RMS_NORM_MUL";
13925            } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
13926                       ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
13927                       ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
13928                ctx->num_additional_fused_ops = 2;
13929                fusion_string = "ROPE_VIEW_SET_ROWS";
13930            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
13931                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
13932                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
13933                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
13934                // view of argsort writes to memory
13935                ctx->fused_ops_write_mask |= 1 << 3;
13936                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
13937                fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
13938            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
13939                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
13940                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
13941                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
13942                // view of argsort writes to memory
13943                ctx->fused_ops_write_mask |= 1 << 4;
13944                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
13945                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
13946            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
13947                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
13948                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
13949                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
13950                // view of argsort writes to memory
13951                ctx->fused_ops_write_mask |= 1 << 3;
13952                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
13953                fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
13954            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
13955                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
13956                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
13957                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
13958                // view of argsort writes to memory
13959                ctx->fused_ops_write_mask |= 1 << 1;
13960                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
13961                fusion_string = "TOPK_MOE_LATE_SOFTMAX";
13962            }
13963            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
                // Look for an additional scale op to fuse; this occurs in deepseek2 and nemotron3 nano.
13965                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
13966                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
13967                    ctx->fused_topk_moe_scale = true;
13968                    ctx->num_additional_fused_ops++;
13969                }
13970            }
13971        }
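        // The final node of a fused run (or the node itself, if nothing was fused) always writes its result to memory.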
13972        ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
13973
13974        // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
13975        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
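        // Submit when enough nodes or matmul bytes have accumulated, when the last
        // node has been reached, or to signal the almost_ready fence.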
13976        bool submit = (submitted_nodes >= nodes_per_submit) ||
13977                      (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
13978                      (i + ctx->num_additional_fused_ops >= last_node) ||
13979                      (almost_ready && !ctx->almost_ready_fence_pending);
13980
13981        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
13982
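        // When perf logging is enabled, record a timestamp query for each enqueued
        // node (or, for the concurrent logger, remember the fusion bookkeeping so
        // timings can be attributed once the graph completes).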
13983        if (vk_perf_logger_enabled && enqueued) {
13984            if (ctx->compute_ctx.expired()) {
13985                compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13986                ctx->compute_ctx = compute_ctx;
13987                ggml_vk_ctx_begin(ctx->device, compute_ctx);
13988            } else {
13989                compute_ctx = ctx->compute_ctx.lock();
13990            }
13991            if (!vk_perf_logger_concurrent) {
13992                // track a single node/fusion for the current query
13993                ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
13994                ctx->query_fusion_names[ctx->query_idx] = fusion_string;
13995                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
13996            } else {
13997                // track a fusion string and number of fused ops for the current node_idx
13998                ctx->query_fusion_names[i] = fusion_string;
13999                ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
14000            }
14001        }
14002
14003        if (enqueued) {
14004            ++submitted_nodes;
14005
14006#ifndef GGML_VULKAN_CHECK_RESULTS
            first_node_in_batch = false;
14010#endif
14011        }
14012
14013        if (submit && enqueued) {
14014            first_node_in_batch = true;
14015            submitted_nodes = 0;
14016            mul_mat_bytes = 0;
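            // Heuristic: double the matmul byte budget for the first few submits so
            // early submits start executing sooner while later ones batch more work.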
14017            if (submit_count < 3) {
14018                mul_mat_bytes_per_submit *= 2;
14019            }
14020            submit_count++;
14021        }
14022        i += ctx->num_additional_fused_ops;
14023        ctx->num_additional_fused_ops = 0;
14024        ctx->fused_ops_write_mask = 0;
14025    }
14026
14027    ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
14028
14029    if (vk_perf_logger_enabled) {
14030        // End the command buffer and submit/wait
14031        GGML_ASSERT(!ctx->compute_ctx.expired());
14032        compute_ctx = ctx->compute_ctx.lock();
14033        ggml_vk_ctx_end(compute_ctx);
14034
14035        ggml_vk_submit(compute_ctx, ctx->device->fence);
14036        VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
14037        ctx->device->device.resetFences({ ctx->device->fence });
14038        ctx->compute_ctx.reset();
14039
14040        // Get the results and pass them to the logger
14041        std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
14042        VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
14043        if (!vk_perf_logger_concurrent) {
14044            // Log each op separately
14045            for (int i = 1; i < ctx->query_idx; i++) {
14046                auto node = ctx->query_nodes[i];
14047                auto name = ctx->query_fusion_names[i];
14048                ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
14049            }
14050        } else {
14051            // Log each group of nodes
14052            int prev_node_idx = 0;
14053            for (int i = 1; i < ctx->query_idx; i++) {
14054                auto cur_node_idx = ctx->query_node_idx[i];
14055                std::vector<ggml_tensor *> nodes;
14056                std::vector<const char *> names;
14057                for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
14058                    if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
14059                        continue;
14060                    }
14061                    nodes.push_back(cgraph->nodes[node_idx]);
14062                    names.push_back(ctx->query_fusion_names[node_idx]);
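                    // skip over the ops that were fused into this node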
14063                    node_idx += ctx->query_fusion_node_count[node_idx];
14064                }
14065                prev_node_idx = cur_node_idx;
14066                ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
14067            }
14068        }
14069        ctx->perf_logger->print_timings();
14070    }
14071
14072    if (!ctx->device->support_async) {
14073        ggml_vk_synchronize(ctx);
14074    }
14075
14076    return GGML_STATUS_SUCCESS;
14077
14078    UNUSED(backend);
14079}
14080
14081// Sort the graph for improved parallelism.
14082static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
14083{
14084    VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
14085    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14086
14087    if (ctx->device->disable_graph_optimize) {
14088        return;
14089    }
14090
14091    auto const &is_empty = [](ggml_tensor * node) -> bool {
14092        return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
14093    };
14094
14095    auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
14096        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
14097            if (dst->src[s] == src) {
14098                return true;
14099            }
14100        }
14101        // implicit dependency if they view the same tensor
14102        const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
14103        const ggml_tensor *src2 = src->view_src ? src->view_src : src;
14104        if (dst2 == src2) {
14105            return true;
14106        }
14107        return false;
14108    };
14109
14110    std::vector<ggml_tensor *> new_order;
14111    std::vector<bool> used(graph->n_nodes, false);
14112    std::set<ggml_tensor *> used_node_set;
14113
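    // Greedily build sets of independent nodes: take the earliest unused node, then
    // pull forward nearby nodes whose dependencies have already been emitted.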
14114    int first_unused = 0;
14115    while (first_unused < graph->n_nodes) {
14116        std::vector<int> current_set;
14117
14118        // Check for fusion patterns and avoid reordering them
14119        auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
14120            if (start + (int)pattern.size() <= graph->n_nodes) {
14121                bool is_pattern = true;
14122                for (size_t j = 0; j < pattern.size(); ++j) {
14123                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
14124                        is_pattern = false;
14125                    }
14126                }
14127                return is_pattern;
14128            }
14129            return false;
14130        };
14131
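        // If a known fusion pattern starts at first_unused, emit it intact so the
        // reordering cannot break it apart.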
14132        auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
14133            if (match_pattern(pattern, first_unused)) {
14134                for (size_t j = 0; j < pattern.size(); ++j) {
14135                    new_order.push_back(graph->nodes[first_unused + j]);
14136                    used_node_set.insert(graph->nodes[first_unused + j]);
14137                    used[first_unused + j] = true;
14138                }
14139                while (first_unused < graph->n_nodes && used[first_unused]) {
14140                    first_unused++;
14141                }
14142                return true;
14143            }
14144            return false;
14145        };
14146
14147        if (keep_pattern(topk_moe_early_softmax_norm)) {
14148            continue;
14149        }
14150        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
14151            continue;
14152        }
14153        if (keep_pattern(topk_moe_early_softmax)) {
14154            continue;
14155        }
14156        if (keep_pattern(topk_moe_late_softmax)) {
14157            continue;
14158        }
14159
14160        // First, grab the next unused node.
14161        current_set.push_back(first_unused);
14162
14163        // Loop through the next N nodes. Grab any that don't depend on other nodes that
14164        // haven't already been run. Nodes that have already been run have used[i] set
14165        // to true. Allow nodes that depend on the previous node if it's a fusion pattern
14166        // that we support (e.g. RMS_NORM + MUL).
        // This first pass only grabs "real" (non-view) nodes; the second pass grabs view nodes.
14168        // The goal is to not interleave real and view nodes in a way that breaks fusion.
14169        const int NUM_TO_CHECK = 20;
14170        for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
14171            if (used[j]) {
14172                continue;
14173            }
14174            if (is_empty(graph->nodes[j])) {
14175                continue;
14176            }
14177            // Don't pull forward nodes from fusion patterns
14178            if (match_pattern(topk_moe_early_softmax_norm, j) ||
14179                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
14180                match_pattern(topk_moe_early_softmax, j) ||
14181                match_pattern(topk_moe_late_softmax, j)) {
14182                continue;
14183            }
14184            bool ok = true;
14185            for (int c = first_unused; c < j; ++c) {
14186                if (!used[c] &&
14187                    is_src_of(graph->nodes[j], graph->nodes[c]) &&
14188                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
14189                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
14190                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
14191                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
14192                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
14193                    ok = false;
14194                    break;
14195                }
14196            }
14197            if (ok) {
14198                current_set.push_back(j);
14199
14200                int rope_idx = j;
14201
14202                // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
14203                if (j > 0 &&
14204                    graph->nodes[j]->op == GGML_OP_MUL &&
14205                    graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
14206                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14207                        if (graph->nodes[k]->op == GGML_OP_ROPE &&
14208                            graph->nodes[k]->src[0] == graph->nodes[j] &&
14209                            // Check that other srcs are already valid
14210                            graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
14211                            (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
14212                            rope_idx = k;
14213                            current_set.push_back(rope_idx);
14214                            used[rope_idx] = true;
14215                            break;
14216                        }
14217                    }
14218                }
14219                // Look for ROPE + VIEW + SET_ROWS and make them consecutive
14220                if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
14221                    int view_idx = -1;
14222                    int set_rows_idx = -1;
14223                    for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
14224                        if (view_idx == -1 &&
14225                            graph->nodes[k]->op == GGML_OP_VIEW &&
14226                            graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
14227                            view_idx = k;
14228                            continue;
14229                        }
14230                        if (view_idx != -1 &&
14231                            set_rows_idx == -1 &&
14232                            graph->nodes[k]->op == GGML_OP_SET_ROWS &&
14233                            graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
14234                            set_rows_idx = k;
14235                            break;
14236                        }
14237                    }
14238                    if (set_rows_idx != -1) {
14239                        current_set.push_back(view_idx);
14240                        current_set.push_back(set_rows_idx);
14241                        used[view_idx] = true;
14242                        used[set_rows_idx] = true;
14243                    }
14244                }
14245                // Look for MUL_MAT_ID + ADD_ID + MUL
14246                if (j > 0 &&
14247                    graph->nodes[j]->op == GGML_OP_ADD_ID &&
14248                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
14249                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14250                        if (graph->nodes[k]->op == GGML_OP_MUL &&
14251                            graph->nodes[k]->src[0] == graph->nodes[j] &&
14252                            // src1 must either be weights or already processed
14253                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
14254                            current_set.push_back(k);
14255                            used[k] = true;
14256                            break;
14257                        }
14258                    }
14259                }
14260                // Look for MUL_MAT + ADD + ADD
14261                if (j > 0 &&
14262                    graph->nodes[j]->op == GGML_OP_ADD &&
14263                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
14264                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14265                        if (graph->nodes[k]->op == GGML_OP_ADD &&
14266                            graph->nodes[k]->src[0] == graph->nodes[j] &&
14267                            // src1 must either be weights or already processed
14268                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
14269                            current_set.push_back(k);
14270                            used[k] = true;
14271                            break;
14272                        }
14273                    }
14274                }
14275            }
14276        }
14277        // Second pass grabs view nodes.
14278        // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
14279        if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
14280            for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
14281                if (used[j]) {
14282                    continue;
14283                }
14284                if (!is_empty(graph->nodes[j])) {
14285                    continue;
14286                }
14287                bool ok = true;
14288                for (int c = first_unused; c < j; ++c) {
14289                    bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
14290                    // skip views whose srcs haven't been processed.
14291                    if (!used[c] &&
14292                        is_src_of(graph->nodes[j], graph->nodes[c]) &&
14293                        !c_in_current_set) {
14294                        ok = false;
14295                        break;
14296                    }
14297                }
14298                if (ok) {
14299                    current_set.push_back(j);
14300                }
14301            }
14302        }
14303
14304        // Push the current set into new_order
14305        for (auto c : current_set) {
14306            new_order.push_back(graph->nodes[c]);
14307            used_node_set.insert(graph->nodes[c]);
14308            used[c] = true;
14309        }
14310        while (first_unused < graph->n_nodes && used[first_unused]) {
14311            first_unused++;
14312        }
14313    }
14314    // Replace the graph with the new order.
14315    for (int i = 0; i < graph->n_nodes; ++i) {
14316        graph->nodes[i] = new_order[i];
14317    }
14318}
14319
14320static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
14321    VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
14322    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14323    vk_event *vkev = (vk_event *)event->context;
14324
14325    vk_context compute_ctx;
14326
14327    if (ctx->compute_ctx.expired()) {
        // Initialize new compute context
14329        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14330        ctx->compute_ctx = compute_ctx;
14331        ggml_vk_ctx_begin(ctx->device, compute_ctx);
14332    } else {
14333        compute_ctx = ctx->compute_ctx.lock();
14334    }
14335
14336    // the backend interface doesn't have an explicit reset, so reset it here
14337    // before we record the command to set it
14338    ctx->device->device.resetEvent(vkev->event);
14339    ctx->device->device.resetFences({ vkev->fence });
14340
14341    ggml_vk_set_event(compute_ctx, vkev->event);
14342
14343    ggml_vk_ctx_end(compute_ctx);
14344
14345    ggml_vk_submit(compute_ctx, {vkev->fence});
14346    ctx->submit_pending = true;
14347    ctx->compute_ctx.reset();
14348}
14349
14350static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
14351    VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
14352    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14353    vk_event *vkev = (vk_event *)event->context;
14354
14355    vk_context compute_ctx;
14356
14357    if (ctx->compute_ctx.expired()) {
        // Initialize new compute context
14359        compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14360        ctx->compute_ctx = compute_ctx;
14361        ggml_vk_ctx_begin(ctx->device, compute_ctx);
14362    } else {
14363        compute_ctx = ctx->compute_ctx.lock();
14364    }
14365
14366    ggml_vk_wait_events(compute_ctx, {vkev->event});
14367    ggml_vk_ctx_end(compute_ctx);
14368    ctx->compute_ctx.reset();
14369}
14370
14371// TODO: enable async and synchronize
14372static ggml_backend_i ggml_backend_vk_interface = {
14373    /* .get_name                = */ ggml_backend_vk_name,
14374    /* .free                    = */ ggml_backend_vk_free,
14375    /* .set_tensor_async        = */ ggml_backend_vk_set_tensor_async,
14376    /* .get_tensor_async        = */ ggml_backend_vk_get_tensor_async,
14377    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
14378    /* .synchronize             = */ ggml_backend_vk_synchronize,
14379    /* .graph_plan_create       = */ NULL,
14380    /* .graph_plan_free         = */ NULL,
14381    /* .graph_plan_update       = */ NULL,
14382    /* .graph_plan_compute      = */ NULL,
14383    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
14384    /* .event_record            = */ ggml_backend_vk_event_record,
14385    /* .event_wait              = */ ggml_backend_vk_event_wait,
14386    /* .graph_optimize          = */ ggml_vk_graph_optimize,
14387};
14388
14389static ggml_guid_t ggml_backend_vk_guid() {
14390    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
14391    return &guid;
14392}
14393
14394ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
14395    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
14396
14397    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
14398    ggml_vk_init(ctx, dev_num);
14399
14400    ggml_backend_t vk_backend = new ggml_backend {
14401        /* .guid    = */ ggml_backend_vk_guid(),
14402        /* .iface   = */ ggml_backend_vk_interface,
14403        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
14404        /* .context = */ ctx,
14405    };
14406
14407    if (!ctx->device->support_async) {
14408        vk_backend->iface.get_tensor_async = nullptr;
14409    }
14410
14411    return vk_backend;
14412}
14413
14414bool ggml_backend_is_vk(ggml_backend_t backend) {
14415    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
14416}
14417
14418int ggml_backend_vk_get_device_count() {
14419    return ggml_vk_get_device_count();
14420}
14421
14422void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
14423    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
14424    int dev_idx = vk_instance.device_indices[device];
14425    ggml_vk_get_device_description(dev_idx, description, description_size);
14426}
14427
14428void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
14429    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
14430    GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
14431
14432    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
14433    vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
14434    vk::PhysicalDeviceMemoryProperties2 memprops = {};
14435    const bool membudget_supported = vk_instance.device_supports_membudget[device];
14436    const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
14437
14438    if (membudget_supported) {
14439        memprops.pNext = &budgetprops;
14440    }
14441    vkdev.getMemoryProperties2(&memprops);
14442
14443    *total = 0;
14444    *free = 0;
14445
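    // Sum device-local heaps; integrated GPUs share system memory, so count every heap.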
14446    for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
14447        const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
14448
14449        if (is_integrated_gpu || (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal)) {
14450            *total += heap.size;
14451
14452            if (membudget_supported && i < budgetprops.heapUsage.size()) {
14453                *free += budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
14454            } else {
14455                *free += heap.size;
14456            }
14457        }
14458    }
14459}
14460
14461static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {
14462    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
14463
14464    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
14465
14466    vk::PhysicalDeviceProperties2 props = {};
14467    device.getProperties2(&props);
14468
14469    return props.properties.deviceType;
14470}
14471
14472static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
14473    GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
14474
14475    vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
14476
14477    const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
14478
14479    bool ext_support = false;
14480
14481    for (const auto& properties : ext_props) {
14482        if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
14483            ext_support = true;
14484            break;
14485        }
14486    }
14487
14488    if (!ext_support) {
14489        return "";
14490    }
14491
14492    vk::PhysicalDeviceProperties2 props = {};
14493    vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
14494
14495    props.pNext = &pci_bus_info;
14496
14497    device.getProperties2(&props);
14498
14499    const uint32_t pci_domain = pci_bus_info.pciDomain;
14500    const uint32_t pci_bus = pci_bus_info.pciBus;
14501    const uint32_t pci_device = pci_bus_info.pciDevice;
    const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // the PCI function is 0-7; the cast avoids a printf format-overflow warning
14503
14504    char pci_bus_id[16] = {};
14505    snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
14506
14507    return std::string(pci_bus_id);
14508}
14509
14510//////////////////////////
14511
14512struct ggml_backend_vk_device_context {
14513    size_t device;
14514    std::string name;
14515    std::string description;
14516    bool is_integrated_gpu;
14517    std::string pci_bus_id;
14518    int op_offload_min_batch_size;
14519};
14520
14521static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
14522    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14523    return ctx->name.c_str();
14524}
14525
14526static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
14527    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14528    return ctx->description.c_str();
14529}
14530
14531static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
14532    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
14533    ggml_backend_vk_get_device_memory(ctx->device, free, total);
14534}
14535
14536static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
14537    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14538    return ggml_backend_vk_buffer_type(ctx->device);
14539}
14540
14541static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
14542    UNUSED(dev);
14543    return ggml_backend_vk_host_buffer_type();
14544}
14545
14546static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
14547    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14548
14549    return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;
14550}
14551
14552static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
14553    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14554
14555    props->name        = ggml_backend_vk_device_get_name(dev);
14556    props->description = ggml_backend_vk_device_get_description(dev);
14557    props->type        = ggml_backend_vk_device_get_type(dev);
14558    props->device_id   = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
14559    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
14560    props->caps = {
14561        /* .async                 = */ true,
14562        /* .host_buffer           = */ true,
14563        /* .buffer_from_host_ptr  = */ false,
14564        /* .events                = */ true,
14565    };
14566}
14567
14568static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
14569    UNUSED(params);
14570    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14571    return ggml_backend_vk_init(ctx->device);
14572}
14573
14574static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
14575    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14576    const vk_device& device = ggml_vk_get_device(ctx->device);
14577
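    // im2col can address its operands through buffer device addresses (BDA) when
    // int64 and BDA support are available, which bypasses maxStorageBufferRange.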
14578    const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
14579                          device->shader_int64 && device->buffer_device_address;
14580
14581    auto const & tensor_size_supported = [&](size_t tensor_size) {
14582        if (tensor_size > device->max_buffer_size) {
14583            return false;
14584        }
        // The maxStorageBufferRange limit doesn't apply when the im2col shaders use
        // BDA, or when shader 64-bit indexing is enabled.
14587        if (!uses_bda && !device->shader_64b_indexing) {
14588            if (tensor_size > device->properties.limits.maxStorageBufferRange) {
14589                return false;
14590            }
14591        }
14592        return true;
14593    };
14594    // reject any tensors larger than the max buffer size
14595    for (int i = 0; i < GGML_MAX_SRC; i++) {
14596        if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
14597            return false;
14598        }
14599    }
14600    if (!tensor_size_supported(ggml_nbytes(op))) {
14601        return false;
14602    }
14603
14604    switch (op->op) {
14605        case GGML_OP_UNARY:
14606            switch (ggml_get_unary_op(op)) {
14607                case GGML_UNARY_OP_EXP:
14608                case GGML_UNARY_OP_GELU:
14609                case GGML_UNARY_OP_GELU_ERF:
14610                case GGML_UNARY_OP_GELU_QUICK:
14611                case GGML_UNARY_OP_SILU:
14612                case GGML_UNARY_OP_RELU:
14613                case GGML_UNARY_OP_XIELU:
14614                case GGML_UNARY_OP_NEG:
14615                case GGML_UNARY_OP_TANH:
14616                case GGML_UNARY_OP_SIGMOID:
14617                case GGML_UNARY_OP_HARDSIGMOID:
14618                case GGML_UNARY_OP_HARDSWISH:
14619                case GGML_UNARY_OP_ABS:
14620                case GGML_UNARY_OP_SOFTPLUS:
14621                case GGML_UNARY_OP_STEP:
14622                case GGML_UNARY_OP_ROUND:
14623                case GGML_UNARY_OP_CEIL:
14624                case GGML_UNARY_OP_FLOOR:
14625                case GGML_UNARY_OP_TRUNC:
14626                    return ggml_is_contiguous(op->src[0]) &&
14627                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14628                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
14629                           (op->src[0]->type == op->type);
14630                default:
14631                    return false;
14632            }
14633        case GGML_OP_GLU:
14634            switch (ggml_get_glu_op(op)) {
14635                case GGML_GLU_OP_GEGLU:
14636                case GGML_GLU_OP_REGLU:
14637                case GGML_GLU_OP_SWIGLU:
14638                case GGML_GLU_OP_SWIGLU_OAI:
14639                case GGML_GLU_OP_GEGLU_ERF:
14640                case GGML_GLU_OP_GEGLU_QUICK:
14641                    return ggml_is_contiguous(op->src[0]) &&
14642                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14643                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
14644                           (op->src[0]->type == op->type);
14645                default:
14646                    return false;
14647            }
14648        case GGML_OP_MUL_MAT:
14649        case GGML_OP_MUL_MAT_ID:
14650            {
14651                ggml_type src0_type = op->src[0]->type;
14652                if (op->op == GGML_OP_MUL_MAT_ID) {
14653                    if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
                        // If there's not enough shared memory for row_ids and the result tile, fall back to CPU
14655                        return false;
14656                    }
14657                }
14658                switch (src0_type) {
14659                    case GGML_TYPE_F32:
14660                    case GGML_TYPE_F16:
14661                    case GGML_TYPE_BF16:
14662                    case GGML_TYPE_Q4_0:
14663                    case GGML_TYPE_Q4_1:
14664                    case GGML_TYPE_Q5_0:
14665                    case GGML_TYPE_Q5_1:
14666                    case GGML_TYPE_Q8_0:
14667                    case GGML_TYPE_Q2_K:
14668                    case GGML_TYPE_Q3_K:
14669                    case GGML_TYPE_Q4_K:
14670                    case GGML_TYPE_Q5_K:
14671                    case GGML_TYPE_Q6_K:
14672                    case GGML_TYPE_IQ1_S:
14673                    case GGML_TYPE_IQ1_M:
14674                    case GGML_TYPE_IQ2_XXS:
14675                    case GGML_TYPE_IQ2_XS:
14676                    case GGML_TYPE_IQ2_S:
14677                    case GGML_TYPE_IQ3_XXS:
14678                    case GGML_TYPE_IQ3_S:
14679                    case GGML_TYPE_IQ4_XS:
14680                    case GGML_TYPE_IQ4_NL:
14681                    case GGML_TYPE_MXFP4:
14682                        break;
14683                    default:
14684                        return false;
14685                }
14686                struct ggml_tensor * a;
14687                struct ggml_tensor * b;
14688                if (op->op == GGML_OP_MUL_MAT) {
14689                    a = op->src[0];
14690                    b = op->src[1];
14691                } else {
14692                    a = op->src[2];
14693                    b = op->src[1];
14694                }
14695                if (a->ne[3] != b->ne[3]) {
14696                    return false;
14697                }
14698                if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
14699                    !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
14700                    return false;
14701                }
14702                if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
14703                    // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
14704                    // So don't support this combination for now.
14705                    return false;
14706                }
14707
14708                return true;
14709            }
14710        case GGML_OP_FLASH_ATTN_EXT:
14711            {
14712                bool coopmat2 = device->coopmat2;
14713                uint32_t HSK = op->src[1]->ne[0];
14714                uint32_t HSV = op->src[2]->ne[0];
14715                if ((HSK % 8) != 0 || (HSV % 8) != 0) {
14716                    return false;
14717                }
14718                if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
14719                    return false;
14720                }
14721                if (op->src[0]->type != GGML_TYPE_F32) {
14722                    return false;
14723                }
14724                if (op->type != GGML_TYPE_F32) {
14725                    return false;
14726                }
14727                if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
14728                    return false;
14729                }
                // It's straightforward to support different K/V dequant types, but it
                // would significantly increase the number of pipelines.
14732                if (op->src[1]->type != op->src[2]->type) {
14733                    return false;
14734                }
14735                switch (op->src[1]->type) {
14736                case GGML_TYPE_F16:
14737                case GGML_TYPE_F32:
14738                case GGML_TYPE_Q4_0:
14739                case GGML_TYPE_Q8_0:
14740                    // supported in scalar and coopmat2 paths
14741                    break;
14742                case GGML_TYPE_Q4_1:
14743                case GGML_TYPE_Q5_0:
14744                case GGML_TYPE_Q5_1:
                // K-quants are currently disabled because the D dimension gets rounded up to 256, which runs inefficiently
14746                //case GGML_TYPE_Q2_K:
14747                //case GGML_TYPE_Q3_K:
14748                //case GGML_TYPE_Q4_K:
14749                //case GGML_TYPE_Q5_K:
14750                //case GGML_TYPE_Q6_K:
14751                //case GGML_TYPE_IQ1_S:
14752                //case GGML_TYPE_IQ1_M:
14753                //case GGML_TYPE_IQ2_XXS:
14754                //case GGML_TYPE_IQ2_XS:
14755                //case GGML_TYPE_IQ2_S:
14756                //case GGML_TYPE_IQ3_XXS:
14757                //case GGML_TYPE_IQ3_S:
14758                //case GGML_TYPE_IQ4_XS:
14759                case GGML_TYPE_IQ4_NL:
14760                    // currently supported only in coopmat2 path
14761                    if (!coopmat2) {
14762                        return false;
14763                    }
14764                    break;
14765                default:
14766                    return false;
14767                }
14768                if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
14769                    // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
14770                    return false;
14771                }
14772                return true;
14773            }
14774        case GGML_OP_GET_ROWS:
14775            {
14776                switch (op->src[0]->type) {
14777                    case GGML_TYPE_F32:
14778                    case GGML_TYPE_F16:
14779                    case GGML_TYPE_BF16:
14780                    case GGML_TYPE_Q4_0:
14781                    case GGML_TYPE_Q4_1:
14782                    case GGML_TYPE_Q5_0:
14783                    case GGML_TYPE_Q5_1:
14784                    case GGML_TYPE_Q8_0:
14785                    case GGML_TYPE_Q2_K:
14786                    case GGML_TYPE_Q3_K:
14787                    case GGML_TYPE_Q4_K:
14788                    case GGML_TYPE_Q5_K:
14789                    case GGML_TYPE_Q6_K:
14790                    case GGML_TYPE_IQ1_S:
14791                    case GGML_TYPE_IQ1_M:
14792                    case GGML_TYPE_IQ2_XXS:
14793                    case GGML_TYPE_IQ2_XS:
14794                    case GGML_TYPE_IQ2_S:
14795                    case GGML_TYPE_IQ3_XXS:
14796                    case GGML_TYPE_IQ3_S:
14797                    case GGML_TYPE_IQ4_XS:
14798                    case GGML_TYPE_IQ4_NL:
14799                    case GGML_TYPE_MXFP4:
14800                    case GGML_TYPE_I32:
14801                        return true;
14802                    default:
14803                        return false;
14804                }
14805            }
14806        case GGML_OP_SET_ROWS:
14807            {
14808                switch (op->type) {
14809                    case GGML_TYPE_F32:
14810                    case GGML_TYPE_F16:
14811                    case GGML_TYPE_BF16:
14812                    case GGML_TYPE_Q4_0:
14813                    case GGML_TYPE_Q4_1:
14814                    case GGML_TYPE_Q5_0:
14815                    case GGML_TYPE_Q5_1:
14816                    case GGML_TYPE_Q8_0:
14817                    case GGML_TYPE_IQ4_NL:
14818                        return true;
14819                    default:
14820                        return false;
14821                }
14822            }
14823        case GGML_OP_CONT:
14824        case GGML_OP_CPY:
14825        case GGML_OP_DUP:
14826            {
14827                ggml_type src0_type = op->src[0]->type;
14828                ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
14829
14830                if (src0_type == GGML_TYPE_F32) {
14831                    switch (src1_type) {
14832                    case GGML_TYPE_F32:
14833                    case GGML_TYPE_F16:
14834                    case GGML_TYPE_BF16:
14835                    case GGML_TYPE_Q4_0:
14836                    case GGML_TYPE_Q4_1:
14837                    case GGML_TYPE_Q5_0:
14838                    case GGML_TYPE_Q5_1:
14839                    case GGML_TYPE_Q8_0:
14840                    case GGML_TYPE_IQ4_NL:
14841                        return true;
14842                    default:
14843                        break;
14844                    }
14845                }
14846                if (src1_type == GGML_TYPE_F32) {
14847                    switch (src0_type) {
14848                    case GGML_TYPE_F16:
14849                    case GGML_TYPE_Q4_0:
14850                    case GGML_TYPE_Q4_1:
14851                    case GGML_TYPE_Q5_0:
14852                    case GGML_TYPE_Q5_1:
14853                    case GGML_TYPE_Q8_0:
14854                    case GGML_TYPE_IQ4_NL:
14855                        return true;
14856                    default:
14857                        break;
14858                    }
14859                }
14860
14861                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
14862                    return true;
14863                }
14864
14865                if (
14866                    (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
14867                    (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
14868                ) {
14869                    return true;
14870                }
14871
                // We can handle copying from a type to the same type if it's
                // either not quantized, or quantized and contiguous.
                // We use f16 or f32 shaders to do the copy,
                // so the type/block size must be a multiple of 2 bytes.
14876                if (src0_type == src1_type &&
14877                    (!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op))) &&
14878                    (ggml_type_size(src0_type) % 2) == 0) {
14879                    return true;
14880                }
14881                return false;
14882            }
14883        case GGML_OP_REPEAT:
14884            return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
14885        case GGML_OP_REPEAT_BACK:
14886            return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
14887        case GGML_OP_ROPE:
14888            return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
14889        case GGML_OP_ROPE_BACK:
14890        case GGML_OP_NONE:
14891        case GGML_OP_RESHAPE:
14892        case GGML_OP_VIEW:
14893        case GGML_OP_PERMUTE:
14894        case GGML_OP_TRANSPOSE:
14895        case GGML_OP_RMS_NORM:
14896            return true;
14897        case GGML_OP_NORM:
14898        case GGML_OP_GROUP_NORM:
14899        case GGML_OP_L2_NORM:
14900            return ggml_is_contiguous(op->src[0]);
14901        case GGML_OP_ADD:
14902        case GGML_OP_SUB:
14903        case GGML_OP_MUL:
14904        case GGML_OP_DIV:
14905            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14906                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
14907                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
14908        case GGML_OP_ADD_ID:
14909            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
14910                   op->type == GGML_TYPE_F32;
14911        case GGML_OP_SILU_BACK:
14912        case GGML_OP_RMS_NORM_BACK:
14913            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14914        case GGML_OP_SQR:
14915        case GGML_OP_SQRT:
14916        case GGML_OP_SIN:
14917        case GGML_OP_COS:
14918        case GGML_OP_CLAMP:
14919            return op->src[0]->type == GGML_TYPE_F32;
14920        case GGML_OP_LEAKY_RELU:
14921        case GGML_OP_OPT_STEP_ADAMW:
14922        case GGML_OP_OPT_STEP_SGD:
14923            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14924        case GGML_OP_LOG:
14925        case GGML_OP_TRI:
14926        case GGML_OP_DIAG:
14927            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14928                   op->type == op->src[0]->type;
14929        case GGML_OP_ARGSORT:
14930            {
14931                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
14932                    return false;
14933                }
                // pipeline_argsort_large_f32 requires the Vulkan memory model.
14935                if (device->vulkan_memory_model) {
14936                    return true;
14937                } else {
14938                    return op->ne[0] <= (1 << device->max_workgroup_size_log2);
14939                }
14940            }
14941        case GGML_OP_TOP_K:
14942            {
14943                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
14944                    return false;
14945                }
                // We could potentially support larger top-k values by using argsort to
                // sort the whole row, but it's not clear that's needed.
14948                uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
14949                if (min_pipeline >= num_topk_pipelines ||
14950                    !device->pipeline_topk_f32[min_pipeline]) {
14951                    return false;
14952                }
14953            }
14954            return true;
14955        case GGML_OP_UPSCALE:
14956            if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
14957                if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
14958                    return false;
14959                }
14960            }
14961            return op->src[0]->type == GGML_TYPE_F32;
14962        case GGML_OP_ACC:
14963            return op->src[0]->type == GGML_TYPE_F32;
14964        case GGML_OP_CONCAT:
14965            return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
14966        case GGML_OP_ADD1:
14967            return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
14968                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
14969                || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
14970        case GGML_OP_ARANGE:
14971        case GGML_OP_FILL:
14972            return op->type == GGML_TYPE_F32;
14973        case GGML_OP_SCALE:
14974            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14975        case GGML_OP_PAD:
14976        case GGML_OP_ROLL:
14977            return op->src[0]->type == GGML_TYPE_F32;
14978        case GGML_OP_DIAG_MASK_INF:
14979            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14980        case GGML_OP_SOFT_MAX:
14981            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
14982                && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
14983        case GGML_OP_SOFT_MAX_BACK:
14984            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
14985                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
14986        case GGML_OP_SUM:
14987        case GGML_OP_SUM_ROWS:
14988        case GGML_OP_MEAN:
14989            return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
14990        case GGML_OP_CUMSUM:
14991            {
14992                if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
14993                    return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
14994                }
14995                return false;
14996            }
14997        case GGML_OP_SOLVE_TRI:
14998            {
14999                if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
15000                    return false;
15001                }
15002                const uint32_t N = op->src[0]->ne[0];
15003                const uint32_t K = op->src[1]->ne[0];
15004                // K dimension limited to workgroup size
15005                if (K > 1u << device->max_workgroup_size_log2) {
15006                    return false;
15007                }
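                // Each batched row needs (N + K) floats of shared memory; reject
                // the op if not even one row fits.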
15008                const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
15009
15010                if (batch_N == 0) {
15011                    return false;
15012                }
15013                return true;
15014            }
15015        case GGML_OP_ARGMAX:
15016            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
15017        case GGML_OP_COUNT_EQUAL:
15018            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
15019                && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
15020        case GGML_OP_IM2COL:
15021            return ggml_is_contiguous(op->src[1])
15022                && op->src[1]->type == GGML_TYPE_F32
15023                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
15024        case GGML_OP_IM2COL_3D:
15025            return op->src[1]->type == GGML_TYPE_F32
15026                && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
15027        case GGML_OP_TIMESTEP_EMBEDDING:
15028            return op->src[0]->type == GGML_TYPE_F32;
15029        case GGML_OP_CONV_2D_DW:
15030            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
15031                && op->src[1]->type == GGML_TYPE_F32;
15032        case GGML_OP_POOL_2D:
15033            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
15034        case GGML_OP_RWKV_WKV6:
15035        case GGML_OP_RWKV_WKV7:
15036            return true; // all inputs are contiguous, see ggml.c
15037        case GGML_OP_SSM_SCAN:
15038            {
15039                for (int i = 0; i < 6; i++) {
15040                    if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
15041                        return false;
15042                    }
15043                }
15044                if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
15045                    return false;
15046                }
15047                if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
15048                    return false;
15049                }
15050
15051                const uint32_t d_state = op->src[0]->ne[0];
15052                const uint32_t head_dim = op->src[0]->ne[1];
15053
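                // A row stride of one float in src[3] (i.e. one value per head)
                // indicates the Mamba-2 layout; Mamba-1 is not supported.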
15054                bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
15055                if (!is_mamba2) {
15056                    return false;
15057                }
15058
15059                if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
15060                    return false;
15061                }
15062
15063                size_t shmem_size = d_state * sizeof(float);
15064
15065                if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
15066                    return false;
15067                }
15068
15069                if (!device->subgroup_basic) {
15070                    return false;
15071                }
15072
15073                return true;
15074            }
15075        case GGML_OP_SSM_CONV:
15076            return op->src[0]->type == GGML_TYPE_F32;
15077        case GGML_OP_CONV_TRANSPOSE_1D:
15078            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
15079        case GGML_OP_CONV_2D:
15080        case GGML_OP_CONV_TRANSPOSE_2D:
15081            {
15082                // Channel-contiguous format is not supported yet.
15083                return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
15084                    op->src[1]->type == GGML_TYPE_F32 &&
15085                    op->type == GGML_TYPE_F32 &&
15086                    ggml_is_contiguous(op->src[0]) &&
15087                    ggml_is_contiguous(op->src[1]) &&
15088                    ggml_is_contiguous(op));
15089            }
15090        default:
15091            return false;
15092    }
15093
15094    UNUSED(dev);
15095}
15096
15097static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
15098    if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
15099        return false;
15100    }
15101
15102    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15103    ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
15104
15105    return buft_ctx->device->idx == ctx->device;
15106}
15107
15108static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
15109    ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
15110
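    // Only offload batches large enough to amortize the transfer cost; note that
    // MUL_MAT_ID batches along ne[2] rather than ne[1].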
15111    return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
15112           (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
15113}
15114
15115static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
15116    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15117    auto device = ggml_vk_get_device(ctx->device);
15118
    // operator new throws on failure, so use the nothrow form to make the null check meaningful
    vk_event *vkev = new (std::nothrow) vk_event;
    if (!vkev) {
        return nullptr;
    }
15123
15124    // The event/fence is expected to initially be in the signaled state.
15125    vkev->event = device->device.createEvent({});
15126    vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
15127    device->device.setEvent(vkev->event);
15128
15129    return new ggml_backend_event {
15130        /* .device  = */ dev,
15131        /* .context = */ vkev,
15132    };
15133}
15134
15135static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
15136    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15137    auto device = ggml_vk_get_device(ctx->device);
15138
15139    vk_event *vkev = (vk_event *)event->context;
15140
15141    device->device.destroyFence(vkev->fence);
15142    device->device.destroyEvent(vkev->event);
15143    delete vkev;
15144    delete event;
15145}
15146
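// Block the host until the event's fence has signaled.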
15147static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
15148    VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
15149    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15150    auto device = ggml_vk_get_device(ctx->device);
15151    vk_event *vkev = (vk_event *)event->context;
15152
15153    VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
15154}
15155
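// Try to import an existing host allocation as a Vulkan buffer via VK_EXT_external_memory_host.
// Returns a null buffer if the extension is unavailable or the pointer/size are misaligned.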
15156static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
15157    if (!device->external_memory_host) {
15158        return {};
15159    }
15160
15161    uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
15162    if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
15163        return {};
15164    }
15165    if (size & (device->min_imported_host_pointer_alignment - 1)) {
15166        return {};
15167    }
15168
15169    const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
15170
15171    vk_buffer buf {};
15172    try {
15173        buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
15174    } catch (vk::SystemError& e) {
15175        GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
15176    }
15177
15178    return buf;
15179}
15180
15181static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
15182    VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
15183    GGML_UNUSED(max_tensor_size);
15184
15185    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15186    auto device = ggml_vk_get_device(ctx->device);
15187
15188    vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
15189
15190    if (!buf) {
15191        return {};
15192    }
15193
15194    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
15195
15196    ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
15197
15198    return ret;
15199}
15200
15201static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
15202    /* .get_name             = */ ggml_backend_vk_device_get_name,
15203    /* .get_description      = */ ggml_backend_vk_device_get_description,
15204    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
15205    /* .get_type             = */ ggml_backend_vk_device_get_type,
15206    /* .get_props            = */ ggml_backend_vk_device_get_props,
15207    /* .init_backend         = */ ggml_backend_vk_device_init,
15208    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
15209    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
15210    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
15211    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
15212    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
15213    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
15214    /* .event_new            = */ ggml_backend_vk_device_event_new,
15215    /* .event_free           = */ ggml_backend_vk_device_event_free,
15216    /* .event_synchronize    = */ ggml_backend_vk_device_event_synchronize,
15217};
15218
15219static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
15220    UNUSED(reg);
15221    return GGML_VK_NAME;
15222}
15223
15224static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
15225    UNUSED(reg);
15226    return ggml_backend_vk_get_device_count();
15227}
15228
15229static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
15230    static std::vector<ggml_backend_dev_t> devices;
15231
15232    static bool initialized = false;
15233
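    // Build the device list lazily, exactly once; the mutex serializes concurrent first calls.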
15234    {
15235        static std::mutex mutex;
15236        std::lock_guard<std::mutex> lock(mutex);
15237        if (!initialized) {
            const char * min_batch_env = getenv("GGML_OP_OFFLOAD_MIN_BATCH");
            const int min_batch_size = min_batch_env ? atoi(min_batch_env) : 32;
15239            for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
15240                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
15241                char desc[256];
15242                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
15243                ctx->device = i;
15244                ctx->name = GGML_VK_NAME + std::to_string(i);
15245                ctx->description = desc;
15246                ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
15247                ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
15248                ctx->op_offload_min_batch_size = min_batch_size;
15249                devices.push_back(new ggml_backend_device {
15250                    /* .iface   = */ ggml_backend_vk_device_i,
15251                    /* .reg     = */ reg,
15252                    /* .context = */ ctx,
15253                });
15254            }
15255            initialized = true;
15256        }
15257    }
15258
15259    GGML_ASSERT(device < devices.size());
15260    return devices[device];
15261}
15262
15263static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
15264    /* .get_name         = */ ggml_backend_vk_reg_get_name,
15265    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
15266    /* .get_device       = */ ggml_backend_vk_reg_get_device,
15267    /* .get_proc_address = */ NULL,
15268};
15269
15270ggml_backend_reg_t ggml_backend_vk_reg() {
15271    static ggml_backend_reg reg = {
15272        /* .api_version = */ GGML_BACKEND_API_VERSION,
15273        /* .iface       = */ ggml_backend_vk_reg_i,
15274        /* .context     = */ nullptr,
15275    };
15276    try {
15277        ggml_vk_instance_init();
15278        return &reg;
15279    } catch (const vk::SystemError& e) {
15280        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
15281        return nullptr;
15282    } catch (const std::exception &e) {
15283        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what());
15284        return nullptr;
15285    } catch (...) {
15286        VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init");
15287        return nullptr;
15288    }
15289}
15290
15291// Extension availability
15292static bool ggml_vk_instance_layer_settings_available() {
15293#ifdef GGML_VULKAN_VALIDATE
15294    // Check if validation layer provides the extension
15295    const std::string layer_name = "VK_LAYER_KHRONOS_validation";
15296    for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
15297        if (layer_name == layer.layerName.data()) {
15298            for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
15299                if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
15300                    return true;
15301                }
15302            }
15303        }
15304    }
15305
15306    std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
15307#endif
15308    return false;
15309}
15310static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
15311#ifdef __APPLE__
15312    // Check for portability enumeration extension for MoltenVK support
15313    for (const auto& properties : instance_extensions) {
15314        if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
15315            return true;
15316        }
15317    }
15318    std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
15319#endif
15320    return false;
15321
15322    UNUSED(instance_extensions);
15323}
15324
static bool ggml_vk_instance_debug_utils_ext_available(
    const std::vector<vk::ExtensionProperties> & instance_extensions) {
    // Check for the debug utils extension, used for debug labels and messages
15329    for (const auto & properties : instance_extensions) {
15330        if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
15331            return true;
15332        }
15333    }
15334
15335    std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
15336    return false;
15337
15338    UNUSED(instance_extensions);
15339}
15340
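// Baseline device requirement for this backend: 16-bit storage buffer access.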
15341static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
15342    VkPhysicalDeviceFeatures2 device_features2;
15343    device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
15344
15345    VkPhysicalDeviceVulkan11Features vk11_features;
15346    vk11_features.pNext = nullptr;
15347    vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
15348    device_features2.pNext = &vk11_features;
15349
15350    vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);
15351
15352    return vk11_features.storageBuffer16BitAccess;
15353}
15354
15355static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
15356    switch (props.vendorID) {
15357    case VK_VENDOR_ID_INTEL:
15358        // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
15359        // while some older hardware (ex. Arc A770) has performance regressions
15360        return arch == vk_device_architecture::INTEL_XE2;
15361    case VK_VENDOR_ID_AMD:
15362        if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
15363            // Workaround for AMD proprietary driver reporting support on all GPUs
15364            return arch == vk_device_architecture::AMD_RDNA3;
15365        }
15366        return true;
15367    default:
15368        return true;
15369    }
15370}
15371
15372// checks
15373
15374#ifdef GGML_VULKAN_CHECK_RESULTS
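// Print the chain of source ops that produced a tensor, to help trace where a bad value originates.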
15375static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
15376    if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
15377        return;
15378    }
15379    for (int j = 0; j < level; j++) {
15380        std::cerr << " ";
15381    }
15382    std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
15383
15384    done.push_back(tensor);
15385
15386    for (int i = 0; i < GGML_MAX_SRC; i++) {
15387        if (tensor->src[i] != nullptr) {
15388            ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
15389        }
15390    }
15391}
15392
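// Print a 10x10 window of values around (i0, i1) in the (i2, i3) plane.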
15393static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
15394    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
15395        return;
15396    }
15397    i0 = std::max(i0, 5);
15398    i1 = std::max(i1, 5);
15399    i2 = std::max(i2, 0);
15400    i3 = std::max(i3, 0);
15401    fprintf(stderr, "         ");
15402    for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
15403        fprintf(stderr, "%7d ", idx1);
15404    }
15405    fprintf(stderr, "\n");
15406    for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
15407        fprintf(stderr, "%7d: ", idx0);
15408        for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
15409            if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
15410                float val;
15411                if (tensor->type == GGML_TYPE_F32) {
15412                    val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
15413                } else if (tensor->type == GGML_TYPE_F16) {
15414                    val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
15415                } else if (tensor->type == GGML_TYPE_I32) {
15416                    val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
15417                } else {
15418                    GGML_ABORT("fatal error");
15419                }
15420                fprintf(stderr, "% 7.2f ", val);
15421            } else {
15422                fprintf(stderr, "        ");
15423            }
15424        }
15425        fprintf(stderr, "\n");
15426    }
15427}
15428
15429static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
15430    void * tensor_data = tensor->data;
15431
15432    const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
15433
15434    if (is_gpu) {
15435        const size_t tensor_size = ggml_nbytes(tensor);
15436        tensor_data = malloc(tensor_size);
15437
15438        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
15439
15440        vk_buffer buffer_gpu = buf_ctx->dev_buffer;
15441        ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
15442    }
15443
15444    std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
15445    std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
15446    if (tensor->src[0] != nullptr) {
15447        std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
15448    }
15449    if (tensor->src[1] != nullptr) {
15450        std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
15451    }
15452    std::cerr << std::endl << "Result:" << std::endl;
15453    ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
15454    std::cerr << std::endl;
15455    std::vector<const ggml_tensor *> done;
15456    ggml_vk_print_graph_origin(tensor, done);
15457
15458    if (is_gpu) {
15459        free(tensor_data);
15460    }
15461}
15462
15463void * comp_result;
15464size_t comp_size;
15465size_t comp_nb[GGML_MAX_DIMS];
15466size_t check_counter = 0;
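// Recompute the node (plus any fused ops) on the CPU and stash the reference result in
// comp_result; ggml_vk_check_results_1 later compares the GPU output against it.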
15467static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
15468    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
15469    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
15470        return;
15471    }
15472
15473    check_counter++;
15474    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
15475        return;
15476    }
15477
15478    VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
15479
15480    struct ggml_init_params iparams = {
15481        /*.mem_size   =*/ 2ul*1024ul*1024ul*1024ul,
15482        /*.mem_buffer =*/ NULL,
15483        /*.no_alloc   =*/ false,
15484    };
15485
15486    struct ggml_context * ggml_ctx = ggml_init(iparams);
15487
15488    std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
15489    const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
15490
15491    std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;
15492    std::vector<void *> cloned_mallocs;
15493
15494    struct ggml_tensor * tensor_clone = nullptr;
15495
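    // Replay the node and any fused follow-up ops, cloning each source back to the host first.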
15496    for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {
15497        tensor = cgraph->nodes[tensor_idx + f];
15498        for (int i = 0; i < GGML_MAX_SRC; i++) {
15499            ggml_tensor * srci = tensor->src[i];
15500            if (srci == nullptr) {
15501                continue;
15502            }
15503            // If a src tensor has been cloned, use that one
15504            auto it = cloned_tensors.find(srci);
15505            if (it != cloned_tensors.end()) {
15506                src_clone[i] = it->second;
15507                continue;
15508            }
15509            ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
15510            size_t srci_size = ggml_nbytes(srci);
15511
15512            src_clone[i] = srci_clone;
15513            void *src_buffer = malloc(srci_size);
15514            cloned_mallocs.push_back(src_buffer);
15515
15516            srci_clone->data = src_buffer;
15517            if (ggml_backend_buffer_is_host(srci->buffer)) {
15518                memcpy(srci_clone->data, srci->data, srci_size);
15519                memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
15520            } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
15521                ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
15522                vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
15523                uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
15524                if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
15525                    for (int i3 = 0; i3 < srci->ne[3]; i3++) {
15526                        for (int i2 = 0; i2 < srci->ne[2]; i2++) {
15527                            const int idx = i3*srci->ne[2] + i2;
15528                            ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
15529                        }
15530                    }
15531
15532                    srci_clone->nb[0] = srci->nb[0];
15533                    srci_clone->nb[1] = srci->nb[1];
15534                    for (int i = 2; i < GGML_MAX_DIMS; i++) {
15535                        srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
15536                    }
15537                } else {
15538                    if (offset + srci_size >= buffer_gpu->size) {
15539                        srci_size = buffer_gpu->size - offset;
15540                    }
15541                    ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
15542                    memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
15543                }
15544            } else {
15545                GGML_ABORT("fatal error");
15546            }
15547
15548            if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
15549                ggml_vk_print_tensor(srci, srci_name[i]);
15550            }
15551        }
15552
15553        if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
15554            const float * params = (const float *)tensor->op_params;
15555            tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
15556            if (src_clone[4]) {
15557                ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
15558            }
15559        } else if (tensor->op == GGML_OP_MUL_MAT) {
15560            tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
15561        } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
15562            tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
15563        } else if (tensor->op == GGML_OP_SUB) {
15564            tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
15565        } else if (tensor->op == GGML_OP_MUL) {
15566            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
15567        } else if (tensor->op == GGML_OP_DIV) {
15568            tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
15569        } else if (tensor->op == GGML_OP_CONCAT) {
15570            tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
15571        } else if (tensor->op == GGML_OP_UPSCALE) {
15572            tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
15573        } else if (tensor->op == GGML_OP_SCALE) {
15574            const float * params = (const float *)tensor->op_params;
15575            tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
15576        } else if (tensor->op == GGML_OP_ADD1) {
15577            tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
15578        } else if (tensor->op == GGML_OP_ARANGE) {
15579            const float start = ggml_get_op_params_f32(tensor, 0);
15580            const float stop = ggml_get_op_params_f32(tensor, 1);
15581            const float step = ggml_get_op_params_f32(tensor, 2);
15582            tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
15583        } else if (tensor->op == GGML_OP_FILL) {
15584            const float value = ggml_get_op_params_f32(tensor, 0);
            tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
15586        } else if (tensor->op == GGML_OP_SQR) {
15587            tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
15588        } else if (tensor->op == GGML_OP_SQRT) {
15589            tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
15590        } else if (tensor->op == GGML_OP_SIN) {
15591            tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
15592        } else if (tensor->op == GGML_OP_COS) {
15593            tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
15594        } else if (tensor->op == GGML_OP_LOG) {
15595            tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
15596        } else if (tensor->op == GGML_OP_TRI) {
15597            tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
15598        } else if (tensor->op == GGML_OP_DIAG) {
15599            tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
15600        } else if (tensor->op == GGML_OP_CLAMP) {
15601            const float * params = (const float *)tensor->op_params;
15602            tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
15603        } else if (tensor->op == GGML_OP_PAD) {
15604            tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
15605                                                                tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
15606        } else if (tensor->op == GGML_OP_REPEAT) {
15607            tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
15608        } else if (tensor->op == GGML_OP_REPEAT_BACK) {
15609            tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
15610        } else if (tensor->op == GGML_OP_ADD) {
15611            tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
15612        } else if (tensor->op == GGML_OP_ACC) {
15613            tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
15614        } else if (tensor->op == GGML_OP_NORM) {
15615            tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
15616        } else if (tensor->op == GGML_OP_GROUP_NORM) {
15617            const float * float_params = (const float *)tensor->op_params;
15618            tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
15619        } else if (tensor->op == GGML_OP_RMS_NORM) {
15620            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
15621        } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
15622            const float eps = ((float *) tensor->op_params)[0];
15623            tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
15624        } else if (tensor->op == GGML_OP_SILU_BACK) {
15625            tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
15626        } else if (tensor->op == GGML_OP_L2_NORM) {
15627            const float eps = ((float *) tensor->op_params)[0];
15628            tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
15629        } else if (tensor->op == GGML_OP_SOFT_MAX) {
15630            if (tensor->src[1] != nullptr) {
15631                const float * params = (const float *)tensor->op_params;
15632                tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
15633            } else {
15634                tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
15635            }
15636        } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
15637            tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
15638        } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
15639            tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
15640        } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
15641            const int n_dims      = ((int32_t *) tensor->op_params)[1];
15642            const int mode        = ((int32_t *) tensor->op_params)[2];
15643            //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3];
15644            const int n_ctx_orig_ggml  = ((int32_t *) tensor->op_params)[4];
15645            const float freq_base       = ((float *) tensor->op_params)[5];
15646            const float freq_scale      = ((float *) tensor->op_params)[6];
15647            const float ext_factor      = ((float *) tensor->op_params)[7];
15648            const float attn_factor     = ((float *) tensor->op_params)[8];
15649            const float beta_fast       = ((float *) tensor->op_params)[9];
15650            const float beta_slow       = ((float *) tensor->op_params)[10];
15651            if (mode & GGML_ROPE_TYPE_MROPE) {
15652                int32_t *sections = ((int32_t *) tensor->op_params) + 11;
15653                if (tensor->op == GGML_OP_ROPE) {
15654                    tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15655                } else {
15656                    tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15657                }
15658            } else {
15659                if (tensor->op == GGML_OP_ROPE) {
15660                    tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15661                } else {
15662                    tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15663                }
15664            }
15665        } else if (tensor->op == GGML_OP_UNARY) {
15666            switch (ggml_get_unary_op(tensor)) {
15667            case GGML_UNARY_OP_EXP:
15668                tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
15669                break;
15670            case GGML_UNARY_OP_SILU:
15671                tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
15672                break;
15673            case GGML_UNARY_OP_GELU:
15674                tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
15675                break;
15676            case GGML_UNARY_OP_GELU_ERF:
15677                tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
15678                break;
15679            case GGML_UNARY_OP_GELU_QUICK:
15680                tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
15681                break;
15682            case GGML_UNARY_OP_RELU:
15683                tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
15684                break;
15685            case GGML_UNARY_OP_XIELU:
15686                tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
15687                ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
15688                ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
15689                ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
15690                ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
15691                break;
15692            case GGML_UNARY_OP_NEG:
15693                tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
15694                break;
15695            case GGML_UNARY_OP_TANH:
15696                tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
15697                break;
15698            case GGML_UNARY_OP_SIGMOID:
15699                tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
15700                break;
15701            case GGML_UNARY_OP_HARDSIGMOID:
15702                tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
15703                break;
15704            case GGML_UNARY_OP_HARDSWISH:
15705                tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
15706                break;
15707            case GGML_UNARY_OP_ABS:
15708                tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
15709                break;
15710            case GGML_UNARY_OP_SOFTPLUS:
15711                tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
15712                break;
15713            case GGML_UNARY_OP_STEP:
15714                tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
15715                break;
15716            case GGML_UNARY_OP_ROUND:
15717                tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
15718                break;
15719            case GGML_UNARY_OP_CEIL:
15720                tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
15721                break;
15722            case GGML_UNARY_OP_FLOOR:
15723                tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
15724                break;
15725            case GGML_UNARY_OP_TRUNC:
15726                tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
15727                break;
15728            default:
15729                std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
15730                GGML_ABORT("fatal error");
15731            }
15732        } else if (tensor->op == GGML_OP_GLU) {
15733            if (src_clone[1] == nullptr) {
15734                tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
15735            } else {
15736                tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
15737            }
15738            ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
15739            ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
15740        } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
15741            if (tensor->src[1] == nullptr) {
15742                tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
15743                tensor_clone->type = tensor->type;
15744            } else {
15745                tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
15746            }
15747        } else if (tensor->op == GGML_OP_CONT) {
15748            tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
15749        } else if (tensor->op == GGML_OP_RESHAPE) {
15750            tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
15751        } else if (tensor->op == GGML_OP_VIEW) {
15752            tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
15753        } else if (tensor->op == GGML_OP_PERMUTE) {
15754            int32_t * params = (int32_t *)tensor->op_params;
15755            tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
15756        } else if (tensor->op == GGML_OP_TRANSPOSE) {
15757            tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
15758        } else if (tensor->op == GGML_OP_GET_ROWS) {
15759            tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
15760        } else if (tensor->op == GGML_OP_ARGSORT) {
15761            tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
15762        } else if (tensor->op == GGML_OP_TOP_K) {
15763            tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
15764        } else if (tensor->op == GGML_OP_SUM) {
15765            tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
15766        } else if (tensor->op == GGML_OP_SUM_ROWS) {
15767            tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
15768        } else if (tensor->op == GGML_OP_CUMSUM) {
15769            tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
15770        } else if (tensor->op == GGML_OP_MEAN) {
15771            tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
15772        } else if (tensor->op == GGML_OP_ARGMAX) {
15773            tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
15774        } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
15775            tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
15776        } else if (tensor->op == GGML_OP_SOLVE_TRI) {
15777            tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
15778        } else if (tensor->op == GGML_OP_IM2COL) {
15779            const int32_t s0 = tensor->op_params[0];
15780            const int32_t s1 = tensor->op_params[1];
15781            const int32_t p0 = tensor->op_params[2];
15782            const int32_t p1 = tensor->op_params[3];
15783            const int32_t d0 = tensor->op_params[4];
15784            const int32_t d1 = tensor->op_params[5];
15785
15786            const bool is_2D = tensor->op_params[6] == 1;
15787            tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
15788        } else if (tensor->op == GGML_OP_IM2COL_3D) {
15789            const int32_t s0 = tensor->op_params[0];
15790            const int32_t s1 = tensor->op_params[1];
15791            const int32_t s2 = tensor->op_params[2];
15792            const int32_t p0 = tensor->op_params[3];
15793            const int32_t p1 = tensor->op_params[4];
15794            const int32_t p2 = tensor->op_params[5];
15795            const int32_t d0 = tensor->op_params[6];
15796            const int32_t d1 = tensor->op_params[7];
15797            const int32_t d2 = tensor->op_params[8];
15798            const int32_t IC = tensor->op_params[9];
15799
15800            tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
15801        } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
15802            const int32_t dim = tensor->op_params[0];
15803            const int32_t max_period = tensor->op_params[1];
15804            tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
15805        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
15806            const int32_t s0 = tensor->op_params[0];
15807            const int32_t p0 = tensor->op_params[1];
15808            const int32_t d0 = tensor->op_params[2];
15809            tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
15810        } else if (tensor->op == GGML_OP_POOL_2D) {
15811            enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
15812            const int32_t k0 = tensor->op_params[1];
15813            const int32_t k1 = tensor->op_params[2];
15814            const int32_t s0 = tensor->op_params[3];
15815            const int32_t s1 = tensor->op_params[4];
15816            const int32_t p0 = tensor->op_params[5];
15817            const int32_t p1 = tensor->op_params[6];
15818
15819            tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
15820        } else if (tensor->op == GGML_OP_CONV_2D) {
15821            const int32_t s0 = tensor->op_params[0];
15822            const int32_t s1 = tensor->op_params[1];
15823            const int32_t p0 = tensor->op_params[2];
15824            const int32_t p1 = tensor->op_params[3];
15825            const int32_t d0 = tensor->op_params[4];
15826            const int32_t d1 = tensor->op_params[5];
15827            tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
15828        } else if (tensor->op == GGML_OP_CONV_2D_DW) {
15829            const int32_t s0 = tensor->op_params[0];
15830            const int32_t s1 = tensor->op_params[1];
15831            const int32_t p0 = tensor->op_params[2];
15832            const int32_t p1 = tensor->op_params[3];
15833            const int32_t d0 = tensor->op_params[4];
15834            const int32_t d1 = tensor->op_params[5];
15835            tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
15836        } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
15837            const int32_t s = tensor->op_params[0];
15838            tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
15839        } else if (tensor->op == GGML_OP_LEAKY_RELU) {
15840            const float * op_params = (const float *)tensor->op_params;
15841            tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
15842        } else if (tensor->op == GGML_OP_RWKV_WKV6) {
15843            tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
15844            src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
15845        } else if (tensor->op == GGML_OP_RWKV_WKV7) {
15846            tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
15847            src_clone[4], src_clone[5], src_clone[6]);
15848        } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
15849            src_clone[0]->flags = tensor->src[0]->flags;
15850            tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
15851            src_clone[2], src_clone[3], src_clone[4]);
15852        } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
15853            src_clone[0]->flags = tensor->src[0]->flags;
15854            tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
15855            src_clone[2]);
15856        } else if (tensor->op == GGML_OP_ADD_ID) {
15857            tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
15858        } else if (tensor->op == GGML_OP_SSM_SCAN) {
15859            tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
15860                                         src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
15861        } else if (tensor->op == GGML_OP_SSM_CONV) {
15862            tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
15863        } else if (tensor->op == GGML_OP_ROLL) {
15864            const int32_t s0 = tensor->op_params[0];
15865            const int32_t s1 = tensor->op_params[1];
15866            const int32_t s2 = tensor->op_params[2];
15867            const int32_t s3 = tensor->op_params[3];
15868            tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);
15869        }
15870        else {
15871            std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
15872            GGML_ABORT("fatal error");
15873        }
15874        cloned_tensors[tensor] = tensor_clone;
15875    }
15876
15877    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
15878    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
15879
15880    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
15881
15882    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
15883        ggml_vk_print_tensor(tensor_clone, "tensor_clone");
15884    }
15885
15886    comp_size = ggml_nbytes(tensor_clone);
15887
15888    comp_result = malloc(comp_size);
15889    memcpy(comp_result, tensor_clone->data, comp_size);
15890    memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
15891
15892    for (auto m : cloned_mallocs) {
15893        free(m);
15894    }
15895
15896    ggml_free(ggml_ctx);
15897
15898    VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
15899}
15900
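// Read the GPU result back and compare it element-wise against the CPU reference computed
// in ggml_vk_check_results_0; aborts on NaN/Inf mismatches or a large average relative error.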
15901static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
15902    ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
15903    if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
15904        return;
15905    }
15906
15907    if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
15908        return;
15909    }
15910
15911    VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
15912
15913    ggml_tensor * src0 = tensor->src[0];
15914    ggml_tensor * src1 = tensor->src[1];
15915    ggml_tensor * src2 = tensor->src[2];
15916    ggml_tensor * src3 = tensor->src[3];
15917
15918    void * tensor_data = tensor->data;
15919
15920    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
15921        size_t tensor_size = ggml_nbytes(tensor);
15922        tensor_data = malloc(tensor_size);
15923
15924        ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
15925
15926        vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
15927        uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
15928        if (offset + tensor_size >= buffer_gpu->size) {
15929            tensor_size = buffer_gpu->size - offset;
15930        }
15931
15932        ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
15933    }
15934
15935    float first_error_result = -1.0f;
15936    float first_error_correct = -1.0f;
15937    std::array<int, 4> first_error = { -1, -1, -1, -1 };
15938    double avg_err = 0.0;
15939    size_t counter = 0;
15940
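    // Walk every element, tracking the first large relative error and the average error.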
15941    for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
15942        for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
15943            for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
15944                for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
15945                    const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
15946                    float correct = 0.0f;
15947                    float result = 0.0f;
15948
15949                    if (buffer_size_fit) {
15950                        if (tensor->type == GGML_TYPE_F32) {
15951                            correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15952                            result  = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15953                        } else if (tensor->type == GGML_TYPE_F16) {
15954                            correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
15955                            result  = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
15956                        } else if (tensor->type == GGML_TYPE_BF16) {
15957                            correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
15958                            result  = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
15959                        } else if (tensor->type == GGML_TYPE_I32) {
15960                            correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15961                            result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15962                        } else if (tensor->type == GGML_TYPE_I64) {
15963                            correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15964                            result  = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15965                        } else {
15966                            std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
15967                        }
                    } else {
                        std::cerr << "ERROR: Output position exceeds CPU reference buffer size in " << ggml_op_name(tensor->op) << std::endl;
                        GGML_ABORT("fatal error");
                    }
15972
15973                    if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
15974                        std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
15975                        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
15976                        if (src0 != nullptr) {
15977                            std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
15978                        }
15979                        if (src1 != nullptr) {
15980                            std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
15981                        }
15982                        if (src2 != nullptr) {
15983                            std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
15984                        }
15985                        if (src3 != nullptr) {
15986                            std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
15987                        }
15988                        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
15989                        std::cerr << std::endl << "Result:" << std::endl;
15990                        ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
15991                        std::cerr << std::endl << "Correct:" << std::endl;
15992                        ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
15993                        std::cerr << std::endl;
15994                        std::vector<const ggml_tensor *> done;
15995                        ggml_vk_print_graph_origin(tensor, done);
15996                        GGML_ABORT("fatal error");
15997                    }
                    // Use relative error when |correct| > 1, absolute error otherwise
                    const double denom = std::max<double>(std::fabs(correct), 1.0);
15999                    if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
16000                        first_error[0] = i0;
16001                        first_error[1] = i1;
16002                        first_error[2] = i2;
16003                        first_error[3] = i3;
16004                        first_error_result = result;
16005                        first_error_correct = correct;
16006                    }
16007
                    // Skip infinite values to avoid a NaN result in avg_err.
                    // NaNs can also appear in valid results; if both values are NaN the error counts as 0.
16010                    if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
16011                        avg_err += std::fabs(correct - result) / denom;
16012                    }
16013                    counter++;
16014                }
16015            }
16016        }
16017    }
16018
16019    avg_err /= counter;
16020
16021    if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
16022        std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
16023        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
16024        if (src0 != nullptr) {
16025            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
16026        }
16027        if (src1 != nullptr) {
16028            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
16029        }
16030        if (src2 != nullptr) {
16031            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
16032        }
16033        if (src3 != nullptr) {
16034            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
16035        }
16036        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
16037        std::cerr << std::endl << "Result:" << std::endl;
16038        ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
16039        std::cerr << std::endl << "Correct:" << std::endl;
16040        ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
16041        std::cerr << std::endl;
16042        std::vector<const ggml_tensor *> done;
16043        ggml_vk_print_graph_origin(tensor, done);
16044    }
16045
16046    if (avg_err > 0.5 || std::isnan(avg_err)) {
16047        std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
16048        std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
16049        if (src0 != nullptr) {
16050            std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
16051        }
16052        if (src1 != nullptr) {
16053            std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
16054        }
16055        if (src2 != nullptr) {
16056            std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
16057        }
16058        if (src3 != nullptr) {
16059            std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
16060        }
16061        std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct  << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
16062        std::cerr << std::endl << "Result:" << std::endl;
16063        ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
16064        std::cerr << std::endl << "Correct:" << std::endl;
16065        ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
16066        std::cerr << std::endl;
16067        std::vector<const ggml_tensor *> done;
16068        ggml_vk_print_graph_origin(tensor, done);
16069        GGML_ABORT("fatal error");
16070    } else {
16071        std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
16072    }
16073
16074    free(comp_result);
16075    comp_result = nullptr;
16076    comp_size = 0;
16077
16078    if (ggml_backend_buffer_is_vk(tensor->buffer)) {
16079        free(tensor_data);
16080    }
16081
16082    VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
16083}
16084#endif
16085
16086GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)