1#include "ggml-vulkan.h"
2#include <vulkan/vulkan_core.h>
3#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS)
4#include <chrono>
5#include "ggml-cpu.h"
6#endif
7
8// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
9#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1
10// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE
11// to avoid conflicts with applications or other libraries that might use it.
12#if VK_HEADER_VERSION >= 301
13namespace vk::detail { class DispatchLoaderDynamic; }
14using vk::detail::DispatchLoaderDynamic;
15#else
16namespace vk { class DispatchLoaderDynamic; }
17using vk::DispatchLoaderDynamic;
18#endif
19DispatchLoaderDynamic & ggml_vk_default_dispatcher();
20#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher()
21
22#include <vulkan/vulkan.hpp>
23
24#include <algorithm>
25#include <cmath>
26#include <iomanip>
27#include <iostream>
28#include <tuple>
29#include <vector>
30#include <sstream>
31#include <utility>
32#include <memory>
33#include <limits>
34#include <map>
35#include <set>
36#include <unordered_map>
38#include <mutex>
39#include <future>
40#include <thread>
41
42#if defined(_MSC_VER)
43# define NOMINMAX 1
44# include <windows.h>
45# define YIELD() YieldProcessor()
46#elif defined(__clang__) || defined(__GNUC__)
47# if defined(__x86_64__) || defined(__i386__)
48# include <immintrin.h>
49# define YIELD() _mm_pause()
50# elif defined(__arm__) || defined(__aarch64__)
51# if defined(__clang__)
52# include <arm_acle.h>
53# define YIELD() __yield()
54# else
55# define YIELD() asm volatile("yield")
56# endif
57# endif
58#endif
59
60#if !defined(YIELD)
61#define YIELD()
62#endif
63
64#include "ggml-impl.h"
65#include "ggml-backend-impl.h"
66
67#include "ggml-vulkan-shaders.hpp"
68
69// remove this once it's more widely available in the SDK
70#if !defined(VK_KHR_shader_bfloat16)
71
72#define VK_KHR_shader_bfloat16 1
73#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
74#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
75#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
76#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
77
78typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
79 VkStructureType sType;
80 void* pNext;
81 VkBool32 shaderBFloat16Type;
82 VkBool32 shaderBFloat16DotProduct;
83 VkBool32 shaderBFloat16CooperativeMatrix;
84} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
85#endif
86
87#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
88#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
89static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
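// Illustrative values for the helpers above: ROUNDUP_POW2(10, 8) == 16 and CEIL_DIV(10, 8) == 2.
// Note that is_pow2(1) evaluates to false with this definition.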
90
91#define VK_VENDOR_ID_AMD 0x1002
92#define VK_VENDOR_ID_APPLE 0x106b
93#define VK_VENDOR_ID_INTEL 0x8086
94#define VK_VENDOR_ID_NVIDIA 0x10de
95
96#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
97
98#define GGML_VK_MAX_NODES 8192
99
100#define VK_CHECK(err, msg) \
101 do { \
102 vk::Result err_ = (err); \
103 if (err_ != vk::Result::eSuccess) { \
104 fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \
105 #err, to_string(err_).c_str(), __FILE__, __LINE__); \
106 exit(1); \
107 } \
108 } while (0)
109
110#ifdef GGML_VULKAN_DEBUG
111#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl
112#else
113#define VK_LOG_DEBUG(msg) ((void) 0)
114#endif // GGML_VULKAN_DEBUG
115
116struct ggml_backend_vk_context;
117
118#define MAX_PARAMETER_COUNT 12
119// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT.
120#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3)
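// With MAX_PARAMETER_COUNT == 12 this evaluates to 9 fused adds.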
121
122typedef std::shared_ptr<struct vk_pipeline_struct> vk_pipeline;
123
124struct vk_pipeline_struct {
125 std::string name;
126 vk::ShaderModule shader_module;
127 vk::PipelineLayout layout;
128 vk::Pipeline pipeline;
129 uint32_t push_constant_size;
130 uint32_t parameter_count;
131 std::array<uint32_t, 3> wg_denoms;
132 uint32_t align;
133 // true if fields have been set by ggml_vk_create_pipeline
134 bool initialized {};
135 // set to true to request the pipeline is compiled
136 std::atomic<bool> needed {};
137 // set to true when the shader has been compiled
138 std::atomic<bool> compiled {};
139 // number of registers used, extracted from pipeline executable properties
140 uint32_t register_count {};
141
142#if defined(VK_EXT_shader_64bit_indexing)
143 bool is_64b_indexing {};
144#endif
145 // linked list of pipelines for multiple compilation variants.
146 // currently only used to compile a 64-bit indexing variant.
147 vk_pipeline next;
148};
149
150typedef std::weak_ptr<vk_pipeline_struct> vk_pipeline_ref;
151
152static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline);
153
154struct vk_matmul_pipeline_struct {
155 vk_pipeline l, m, s;
156 vk_pipeline a_l, a_m, a_s;
157    // Returns true when all unaligned pipelines are null.
158    // Only the unaligned variants are checked, since at least one unaligned pipeline must exist,
159    // while the aligned pipelines are optional.
160 bool is_empty() const {
161 return l == nullptr && m == nullptr && s == nullptr;
162 }
163};
164typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;
165
166struct vk_matmul_pipeline2 {
167 vk_matmul_pipeline2() {
168 f16acc = std::make_shared<vk_matmul_pipeline_struct>();
169 f32acc = std::make_shared<vk_matmul_pipeline_struct>();
170 }
171 vk_matmul_pipeline f32acc;
172 vk_matmul_pipeline f16acc;
173};
174
175struct vk_device_struct;
176typedef std::shared_ptr<vk_device_struct> vk_device;
177typedef std::weak_ptr<vk_device_struct> vk_device_ref;
178
179struct vk_buffer_struct;
180typedef std::shared_ptr<vk_buffer_struct> vk_buffer;
181typedef std::weak_ptr<vk_buffer_struct> vk_buffer_ref;
182
183struct ggml_backend_vk_buffer_type_context {
184 std::string name;
185 vk_device device;
186};
187
188struct vk_queue;
189
190// Stores command pool/buffers. There's an instance of this
191// for each (context,queue) pair and for each (device,queue) pair.
192struct vk_command_pool {
193 void init(vk_device& device, vk_queue *q_);
194 void destroy(vk::Device& device);
195
196 vk::CommandPool pool;
197 uint32_t cmd_buffer_idx;
198 std::vector<vk::CommandBuffer> cmd_buffers;
199
200 vk_queue *q;
201};
202
203// Prevent simultaneous submissions to the same queue.
204// This could be per vk_queue if we stopped having two vk_queue structures
205// sharing the same vk::Queue.
206static std::mutex queue_mutex;
207
208struct vk_queue {
209 uint32_t queue_family_index;
210 vk::Queue queue;
211
212 vk_command_pool cmd_pool;
213
214 vk::PipelineStageFlags stage_flags;
215
216 bool transfer_only;
217
218 // copy everything except the cmd_pool
219 void copyFrom(vk_queue &other) {
220 queue_family_index = other.queue_family_index;
221 queue = other.queue;
222 stage_flags = other.stage_flags;
223 transfer_only = other.transfer_only;
224 }
225};
226
227static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft);
228static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
229static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft);
230static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft);
231static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor);
232static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = {
233 /* .get_name = */ ggml_backend_vk_buffer_type_name,
234 /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer,
235 /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment,
236 /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size,
237 /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size,
238 /* .is_host = */ NULL,
239};
240
241class vk_memory_logger;
242class vk_perf_logger;
243static void ggml_vk_destroy_buffer(vk_buffer& buf);
244static void ggml_vk_synchronize(ggml_backend_vk_context * ctx);
245
246static constexpr uint32_t mul_mat_vec_max_cols = 8;
247static constexpr uint32_t p021_max_gqa_ratio = 8;
248
249enum vk_device_architecture {
250 OTHER,
251 AMD_GCN,
252 AMD_RDNA1,
253 AMD_RDNA2,
254 AMD_RDNA3,
255 INTEL_XE2,
256 NVIDIA_PRE_TURING,
257 NVIDIA_TURING,
258};
259
260static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
261 vk::PhysicalDeviceProperties props = device.getProperties();
262
263 if (props.vendorID == VK_VENDOR_ID_AMD) {
264 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
265
266 bool amd_shader_core_properties = false;
267 bool integer_dot_product = false;
268 bool subgroup_size_control = false;
269
270 for (const auto& properties : ext_props) {
271 if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
272 amd_shader_core_properties = true;
273 } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
274 integer_dot_product = true;
275 } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
276 subgroup_size_control = true;
277 }
278 }
279
280 if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
281 return vk_device_architecture::OTHER;
282 }
283
284 vk::PhysicalDeviceProperties2 props2;
285 vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
286 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
287 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
288
289 props2.pNext = &shader_core_props_amd;
290 shader_core_props_amd.pNext = &integer_dot_props;
291 integer_dot_props.pNext = &subgroup_size_control_props;
292
293 device.getProperties2(&props2);
294
295 if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
296 return vk_device_architecture::AMD_GCN;
297 }
298 if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
299 // RDNA
300 if (shader_core_props_amd.wavefrontsPerSimd == 20) {
301 return vk_device_architecture::AMD_RDNA1;
302 }
303 if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
304 return vk_device_architecture::AMD_RDNA3;
305 }
306 return vk_device_architecture::AMD_RDNA2;
307 }
308 } else if (props.vendorID == VK_VENDOR_ID_INTEL) {
309 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
310
311 bool subgroup_size_control = false;
312
313 for (const auto& properties : ext_props) {
314 if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
315 subgroup_size_control = true;
316 }
317 }
318
319 if (!subgroup_size_control) {
320 return vk_device_architecture::OTHER;
321 }
322
323 vk::PhysicalDeviceProperties2 props2;
324 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
325
326 props2.pNext = &subgroup_size_control_props;
327 device.getProperties2(&props2);
328
329 if (subgroup_size_control_props.minSubgroupSize == 16) {
330            // The Xe2 architecture uses SIMD16, while previous Xe and Gen architectures use SIMD8.
331            // The minimum subgroup size matches the SIMD width, so the architecture can be distinguished by checking this value.
332 // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html
333 // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
334 return vk_device_architecture::INTEL_XE2;
335 }
336 } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) {
337 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
338
339 bool cooperative_matrix = false;
340 bool sm_builtins = false;
341
342 // Detect "pre-turing" based on lack of coopmat support.
343 for (const auto& properties : ext_props) {
344 if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
345 cooperative_matrix = true;
346 } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
347 sm_builtins = true;
348 }
349 }
350
351 if (!cooperative_matrix) {
352 return vk_device_architecture::NVIDIA_PRE_TURING;
353 }
354
355 if (sm_builtins) {
356 vk::PhysicalDeviceProperties2 props2;
357 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
358
359 props2.pNext = &sm_props;
360
361 device.getProperties2(&props2);
362
363 // Turing has 32, following architectures have 48
364 if (sm_props.shaderWarpsPerSM == 32) {
365 return vk_device_architecture::NVIDIA_TURING;
366 }
367 }
368 }
369 return vk_device_architecture::OTHER;
370}
371
372enum vk_conv_shapes {
373 CONV_SHAPE_128x128,
374 CONV_SHAPE_64x32,
375 CONV_SHAPE_32x256,
376 CONV_SHAPE_COUNT,
377};
378
379struct vk_conv_block_size {
380 uint32_t K;
381 uint32_t NPQ;
382 uint32_t CRS;
383};
384
385vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
386 // K NPQ CRS
387 { 128, 128, 16 }, // CONV_SHAPE_128x128
388 { 64, 32, 32 }, // CONV_SHAPE_64x32
389 { 32, 256, 16 }, // CONV_SHAPE_32x256
390};
391
392enum dmmv_wg_sizes {
393 DMMV_WG_SIZE_SUBGROUP,
394 DMMV_WG_SIZE_LARGE,
395 DMMV_WG_SIZE_COUNT,
396};
397
398enum FaCodePath {
399 FA_SCALAR,
400 FA_COOPMAT1,
401 FA_COOPMAT2,
402};
403
404struct vk_fa_pipeline_state {
405 vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
406 : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
407
408 uint32_t HSK, HSV;
409 bool small_rows, small_cache;
410 FaCodePath path;
411 bool aligned;
412 bool f32acc;
413 uint32_t flags;
414
415 bool operator<(const vk_fa_pipeline_state &b) const {
416 return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
417 std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
418 }
419};
420
421struct vk_conv2d_pipeline_state {
422 vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
423 : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
424
425 uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
426
427 bool operator<(const vk_conv2d_pipeline_state &b) const {
428 return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
429 std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
430 }
431};
432
433struct vk_solve_tri_pipeline_state {
434 vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
435 : N(N), K(K) {}
436
437 uint32_t N, K;
438
439 bool operator<(const vk_solve_tri_pipeline_state &b) const {
440 return std::tie(N, K) <
441 std::tie(b.N, b.K);
442 }
443};
444
445enum shader_reduction_mode {
446 SHADER_REDUCTION_MODE_SHMEM,
447 SHADER_REDUCTION_MODE_HYBRID,
448 SHADER_REDUCTION_MODE_SUBGROUP,
449 SHADER_REDUCTION_MODE_COUNT,
450};
451
452// argsort pipelines for up to 1<<10 invocations per workgroup
453static constexpr uint32_t num_argsort_pipelines = 11;
454static constexpr uint32_t num_topk_moe_pipelines = 10;
455static constexpr uint32_t num_topk_pipelines = 11;
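// For argsort, one pipeline per power-of-two workgroup size from 1 up to 1<<10 (11 sizes in total).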
456
457static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
458 GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
459 GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
460 GGML_OP_RESHAPE };
461
462static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
463 GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
464 GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
465 GGML_OP_DIV, GGML_OP_RESHAPE };
466
467static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
468 GGML_OP_VIEW, GGML_OP_GET_ROWS };
469
470static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
471 GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
472 GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
473
474//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
475//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
476//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
477//node #981 ( VIEW): ffn_moe_topk-15 ( 0K) [Vulka ] use=4: ffn_moe_argsort-15 ( 0K) [Vulka ]
478//node #982 ( GET_ROWS): ffn_moe_weights-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 (re ( 0K) [Vulka ] ffn_moe_topk-15 ( 0K) [Vulka ]
479//node #983 ( RESHAPE): ffn_moe_weights-15 ( ( 0K) [Vulka ] use=2: ffn_moe_weights-15 ( 0K) [Vulka ]
480//node #984 ( SUM_ROWS): ffn_moe_weights_sum- ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ]
481//node #985 ( CLAMP): ffn_moe_weights_sum_ ( 0K) [Vulka ] use=1: ffn_moe_weights_sum- ( 0K) [Vulka ]
482//node #986 ( DIV): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights-15 ( ( 0K) [Vulka ] ffn_moe_weights_sum_ ( 0K) [Vulka ]
483//node #987 ( RESHAPE): ffn_moe_weights_norm ( 0K) [Vulka ] use=1: ffn_moe_weights_norm ( 0K) [Vulka ]
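// Each edge below is { consumer node index, src slot, producer node index } within the fused node list above.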
484static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
485 { 1, 0, 0 }, // reshape->src[0] == softmax
486 { 2, 0, 0 }, // argsort->src[0] == softmax
487 { 3, 0, 2 }, // view->src[0] == argsort
488 { 4, 0, 1 }, // get_rows->src[0] == reshape
489 { 4, 1, 3 }, // get_rows->src[1] == view
490 { 5, 0, 4 }, // reshape->src[0] == get_rows
491 { 6, 0, 5 }, // sum_rows->src[0] == reshape
492 { 7, 0, 6 }, // clamp->src[0] == sum_rows
493 { 8, 0, 5 }, // div->src[0] == reshape
494 { 8, 1, 7 }, // div->src[1] == clamp
495 { 9, 0, 8 }, // reshape->src[0] == div
496};
497
498//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
499//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
500//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
501//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
502//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
503//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
504//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
505//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
506//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
507//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
508//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
509static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
510 { 1, 0, 0 }, // reshape->src[0] == sigmoid
511 { 2, 0, 0 }, // add->src[0] == sigmoid
512 { 3, 0, 2 }, // argsort->src[0] == add
513 { 4, 0, 3 }, // view->src[0] == argsort
514 { 5, 0, 1 }, // get_rows->src[0] == reshape
515 { 5, 1, 4 }, // get_rows->src[1] == view
516 { 6, 0, 5 }, // reshape->src[0] == get_rows
517 { 7, 0, 6 }, // sum_rows->src[0] == reshape
518 { 8, 0, 7 }, // clamp->src[0] == sum_rows
519 { 9, 0, 6 }, // div->src[0] == reshape
520 { 9, 1, 8 }, // div->src[1] == clamp
521 {10, 0, 9 }, // reshape->src[0] == div
522};
523
524// same as early_softmax_norm but ending after the get_rows
525static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
526 { 1, 0, 0 }, // reshape->src[0] == softmax
527 { 2, 0, 0 }, // argsort->src[0] == softmax
528 { 3, 0, 2 }, // view->src[0] == argsort
529 { 4, 0, 1 }, // get_rows->src[0] == reshape
530 { 4, 1, 3 }, // get_rows->src[1] == view
531};
532
533//node #652 ( ARGSORT): ffn_moe_argsort-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 ( 0K) [Vulka ]
534//node #653 ( VIEW): ffn_moe_topk-11 ( 0K) [Vulka ] use=7: ffn_moe_argsort-11 ( 0K) [Vulka ]
535//node #654 ( GET_ROWS): ffn_moe_weights-11 ( 0K) [Vulka ] use=1: ffn_moe_probs-11 (re ( 0K) [Vulka ] ffn_moe_topk-11 ( 0K) [Vulka ]
536//node #655 ( RESHAPE): ffn_moe_weights-11 ( ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( 0K) [Vulka ]
537//node #656 ( SOFT_MAX): node_656 ( 0K) [Vulka ] use=1: ffn_moe_weights-11 ( ( 0K) [Vulka ]
538//node #657 ( RESHAPE): ffn_moe_weights_soft ( 0K) [Vulka ] use=1: node_656 ( 0K) [Vulka ]
539static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
540 { 1, 0, 0 }, // view->src[0] == argsort
541 { 2, 1, 1 }, // get_rows->src[1] == view
542 { 3, 0, 2 }, // reshape->src[0] == get_rows
543 { 4, 0, 3 }, // soft_max->src[0] == reshape
544 { 5, 0, 4 }, // reshape->src[0] == soft_max
545};
546
547enum topk_moe_mode {
548 TOPK_MOE_EARLY_SOFTMAX,
549 TOPK_MOE_EARLY_SOFTMAX_NORM,
550 TOPK_MOE_LATE_SOFTMAX,
551 TOPK_MOE_SIGMOID_NORM_BIAS,
552 TOPK_MOE_COUNT,
553};
554
555static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
556 { 1, 0, 0 }, // view->src[0] == rope
557 { 2, 0, 1 }, // set_rows->src[0] == view
558};
559
560static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_view_set_rows_edges {
561 { 1, 0, 0 }, // mul->src[0] == rms
562 { 2, 0, 1 }, // rope->src[0] == mul
563 { 3, 0, 2 }, // view->src[0] == rope
564 { 4, 0, 3 }, // set_rows->src[0] == view
565};
566
567
568struct vk_device_struct {
569 std::recursive_mutex mutex;
570
571 vk::PhysicalDevice physical_device;
572 vk::PhysicalDeviceProperties properties;
573 std::string name;
574 uint64_t max_memory_allocation_size;
575 uint64_t max_buffer_size;
576 uint64_t suballocation_block_size;
577 uint64_t min_imported_host_pointer_alignment;
578 bool external_memory_host {};
579 bool fp16;
580 bool bf16;
581 bool pipeline_robustness;
582 bool memory_priority;
583 vk::Device device;
584 uint32_t vendor_id;
585 vk::DriverId driver_id;
586 vk_device_architecture architecture;
587 vk_queue compute_queue;
588 vk_queue transfer_queue;
589 bool single_queue;
590 bool support_async;
591 uint32_t subgroup_size;
592 uint32_t subgroup_size_log2;
593 uint32_t shader_core_count;
594 bool uma;
595 bool prefer_host_memory;
596 bool float_controls_rte_fp16;
597 bool subgroup_basic;
598 bool subgroup_arithmetic;
599 bool subgroup_shuffle;
600 bool subgroup_ballot;
601 bool subgroup_clustered;
602 bool subgroup_vote;
603 bool multi_add;
604 bool shader_int64;
605 bool buffer_device_address;
606 bool vulkan_memory_model;
607
608 bool add_rms_fusion;
609 uint32_t partials_binding_alignment;
610
611 bool shader_64b_indexing;
612
613 bool integer_dot_product;
614 // 0: default, 1: force mmvq, -1: disable mmvq
615 int32_t mmvq_mode;
616
617 bool subgroup_size_control;
618 uint32_t subgroup_min_size;
619 uint32_t subgroup_max_size;
620 bool subgroup_require_full_support;
621
622 // floor(log2(maxComputeWorkGroupInvocations))
623 uint32_t max_workgroup_size_log2 {};
624
625 bool coopmat_support;
626 bool coopmat_acc_f32_support {};
627 bool coopmat_acc_f16_support {};
628 bool coopmat_bf16_support {};
629 bool coopmat_support_16x16x16_f16acc {};
630 bool coopmat_support_16x16x16_f32acc {};
631 bool coopmat1_fa_support {};
632 uint32_t coopmat_m;
633 uint32_t coopmat_n;
634 uint32_t coopmat_k;
635
636 bool coopmat_int_support;
637 uint32_t coopmat_int_m;
638 uint32_t coopmat_int_n;
639 uint32_t coopmat_int_k;
640
641 bool coopmat2;
642
643 bool pipeline_executable_properties_support {};
644
645 size_t idx;
646
647 bool mul_mat_l[GGML_TYPE_COUNT];
648 bool mul_mat_m[GGML_TYPE_COUNT];
649 bool mul_mat_s[GGML_TYPE_COUNT];
650 bool mul_mat_id_l[GGML_TYPE_COUNT];
651 bool mul_mat_id_m[GGML_TYPE_COUNT];
652 bool mul_mat_id_s[GGML_TYPE_COUNT];
653
654 vk::DescriptorSetLayout dsl;
655
656 vk_matmul_pipeline pipeline_matmul_f32 {};
657 vk_matmul_pipeline pipeline_matmul_f32_f16 {};
658 vk_matmul_pipeline pipeline_matmul_bf16 {};
659 vk_matmul_pipeline2 pipeline_matmul_f16;
660 vk_matmul_pipeline2 pipeline_matmul_f16_f32;
661
662 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
663 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
664 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
665
666 vk_matmul_pipeline pipeline_matmul_id_f32 {};
667 vk_matmul_pipeline pipeline_matmul_id_bf16 {};
668 vk_matmul_pipeline2 pipeline_matmul_id_f16;
669 vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
670
671 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
672 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_COUNT];
673
674 vk_pipeline pipeline_matmul_split_k_reduce;
675 vk_pipeline pipeline_quantize_q8_1_x4;
676
677 vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
678 vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
679 vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
680 vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
681
682 vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
683 vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
684
685 vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
686 vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
687 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
688 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
689 vk_pipeline pipeline_acc_f32;
690
691 // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
692 vk_pipeline pipeline_add[2][2][2];
693 vk_pipeline pipeline_add_norepeat[2][2][2];
694 vk_pipeline pipeline_sub[2][2][2];
695 vk_pipeline pipeline_sub_norepeat[2][2][2];
696 vk_pipeline pipeline_mul[2][2][2];
697 vk_pipeline pipeline_mul_norepeat[2][2][2];
698 vk_pipeline pipeline_div[2][2][2];
699 vk_pipeline pipeline_div_norepeat[2][2][2];
700 vk_pipeline pipeline_add_rms[2][2][2];
701 vk_pipeline pipeline_add_rms_norepeat[2][2][2];
702
703 // indexed by num_additional_fused_ops == num_adds - 1
704 vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS];
705 vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS];
706
707 vk_pipeline pipeline_add_id_f32;
708
709 vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
710 vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32;
711 vk_pipeline pipeline_scale_f32;
712 vk_pipeline pipeline_sqr_f32;
713 vk_pipeline pipeline_sqrt_f32;
714 vk_pipeline pipeline_sin_f32;
715 vk_pipeline pipeline_cos_f32;
716 vk_pipeline pipeline_log[2];
717 vk_pipeline pipeline_tri[2];
718 vk_pipeline pipeline_diag[2];
719 vk_pipeline pipeline_clamp_f32;
720 vk_pipeline pipeline_pad_f32;
721 vk_pipeline pipeline_roll_f32;
722 vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
723 vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
724 vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
725 vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
726 vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
727 vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
728 vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
729 vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
730 vk_pipeline pipeline_norm_f32;
731 vk_pipeline pipeline_group_norm_f32;
732 vk_pipeline pipeline_rms_norm_f32;
733 vk_pipeline pipeline_rms_norm_mul_f32;
734 vk_pipeline pipeline_rms_norm_partials_f32;
735 vk_pipeline pipeline_rms_norm_mul_partials_f32;
736 vk_pipeline pipeline_rms_norm_mul_rope_f32_f32;
737 vk_pipeline pipeline_rms_norm_mul_rope_f32_f16;
738 vk_pipeline pipeline_rms_norm_back_f32;
739 vk_pipeline pipeline_l2_norm_f32;
740
741 // [src/dst 0=fp32,1=fp16]
742 vk_pipeline pipeline_exp[2];
743 vk_pipeline pipeline_gelu[2];
744 vk_pipeline pipeline_gelu_erf[2];
745 vk_pipeline pipeline_gelu_quick[2];
746 vk_pipeline pipeline_silu[2];
747 vk_pipeline pipeline_relu[2];
748 vk_pipeline pipeline_xielu[2];
749 vk_pipeline pipeline_neg[2];
750 vk_pipeline pipeline_tanh[2];
751 vk_pipeline pipeline_sigmoid[2];
752 vk_pipeline pipeline_hardsigmoid[2];
753 vk_pipeline pipeline_hardswish[2];
754 vk_pipeline pipeline_abs[2];
755 vk_pipeline pipeline_softplus[2];
756 vk_pipeline pipeline_step[2];
757 vk_pipeline pipeline_round[2];
758 vk_pipeline pipeline_ceil[2];
759 vk_pipeline pipeline_floor[2];
760 vk_pipeline pipeline_trunc[2];
761
762 vk_pipeline pipeline_add1_f16_f16;
763 vk_pipeline pipeline_add1_f16_f32;
764 vk_pipeline pipeline_add1_f32_f32;
765
766 vk_pipeline pipeline_arange_f32;
767
768 vk_pipeline pipeline_fill_f32;
769
770 vk_pipeline pipeline_geglu[2];
771 vk_pipeline pipeline_reglu[2];
772 vk_pipeline pipeline_swiglu[2];
773 vk_pipeline pipeline_swiglu_oai[2];
774 vk_pipeline pipeline_geglu_erf[2];
775 vk_pipeline pipeline_geglu_quick[2];
776
777 vk_pipeline pipeline_leaky_relu_f32;
778 vk_pipeline pipeline_silu_back_f32;
779 vk_pipeline pipeline_diag_mask_inf_f32;
780 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
781 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
782 vk_pipeline pipeline_soft_max_back_f32;
783
784 vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
785 vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
786 vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
787
788 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
789 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
790 vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16, pipeline_rope_multi_f32_f16;
791 vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
792 vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
793 vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
794 vk_pipeline pipeline_topk_f32[num_topk_pipelines];
795 vk_pipeline pipeline_sum_rows_f32;
796 vk_pipeline pipeline_cumsum_f32;
797 vk_pipeline pipeline_cumsum_small_f32;
798 vk_pipeline pipeline_cumsum_multipass1_f32;
799 vk_pipeline pipeline_cumsum_multipass2_f32;
800 vk_pipeline pipeline_argmax_f32;
801 vk_pipeline pipeline_count_equal_i32;
802 std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
803 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
804 vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
805 vk_pipeline pipeline_timestep_embedding_f32;
806 vk_pipeline pipeline_conv_transpose_1d_f32;
807 vk_pipeline pipeline_pool2d_f32;
808 vk_pipeline pipeline_rwkv_wkv6_f32;
809 vk_pipeline pipeline_rwkv_wkv7_f32;
810 vk_pipeline pipeline_ssm_scan_f32_d128;
811 vk_pipeline pipeline_ssm_scan_f32_d256;
812 vk_pipeline pipeline_ssm_conv_f32;
813 vk_pipeline pipeline_opt_step_adamw_f32;
814 vk_pipeline pipeline_opt_step_sgd_f32;
815 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
816 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
817 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT];
818 std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT];
819 vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
820 vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
821
822 std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
823
824 std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
825
826 vk_pipeline pipeline_flash_attn_split_k_reduce;
827 vk_pipeline pipeline_count_experts;
828
829 // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
830 vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
831
832 std::vector<vk_pipeline_ref> all_pipelines;
833
834 std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
835
836 vk::Fence fence;
837 vk_buffer sync_staging;
838
839 ggml_backend_buffer_type buffer_type;
840
841 bool disable_fusion;
842 bool disable_host_visible_vidmem;
843 bool allow_sysmem_fallback;
844 bool disable_graph_optimize;
845
846 std::unique_ptr<vk_memory_logger> memory_logger;
847
848 ~vk_device_struct() {
849 VK_LOG_DEBUG("destroy device " << name);
850
851 device.destroyFence(fence);
852
853 ggml_vk_destroy_buffer(sync_staging);
854
855 compute_queue.cmd_pool.destroy(device);
856 transfer_queue.cmd_pool.destroy(device);
857
858 for (auto& pipeline : all_pipelines) {
859 if (pipeline.expired()) {
860 continue;
861 }
862
863 vk_pipeline pl = pipeline.lock();
864 ggml_vk_destroy_pipeline(device, pl);
865 }
866 all_pipelines.clear();
867
868 device.destroyDescriptorSetLayout(dsl);
869
870 device.destroy();
871 }
872};
873
874void vk_command_pool::init(vk_device& device, vk_queue *q_) {
875 cmd_buffer_idx = 0;
876 q = q_;
877
878 vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
879 pool = device->device.createCommandPool(command_pool_create_info);
880}
881
882void vk_command_pool::destroy(vk::Device& device) {
883 device.destroyCommandPool(pool);
884 pool = nullptr;
885 cmd_buffers.clear();
886}
887
888struct vk_buffer_struct {
889 vk::Buffer buffer = VK_NULL_HANDLE;
890 vk::DeviceMemory device_memory = VK_NULL_HANDLE;
891 vk::MemoryPropertyFlags memory_property_flags;
892 void * ptr;
893 size_t size = 0;
894 vk::DeviceAddress bda_addr {};
895
896 vk_device device;
897
898 ~vk_buffer_struct() {
899 if (size == 0) {
900 return;
901 }
902 VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")");
903
904 device->device.freeMemory(device_memory);
905 device->device.destroyBuffer(buffer);
906 }
907};
908
909struct vk_subbuffer {
910 vk_buffer buffer;
911 uint64_t offset;
912 uint64_t size;
913
914 operator vk::DescriptorBufferInfo() const {
915 return { buffer->buffer, offset, size };
916 }
917};
918
919// vk_event is used for the event-related backend interfaces. It uses 'event' for
920// event_wait and 'fence' for event_synchronize. Polling on an event for
921// event_synchronize wouldn't be sufficient to wait for command buffers to complete,
922// and would lead to validation errors.
923struct vk_event {
924 vk::Event event;
925 vk::Fence fence;
926};
927
928struct vk_semaphore {
929 vk::Semaphore s;
930 uint64_t value;
931};
932
933struct vk_submission {
934 vk::CommandBuffer buffer;
935 std::vector<vk_semaphore> wait_semaphores;
936 std::vector<vk_semaphore> signal_semaphores;
937};
938
939typedef std::vector<vk_submission> vk_sequence;
940
941struct vk_mat_mat_push_constants {
942 uint32_t M; uint32_t N; uint32_t K;
943 uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
944 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
945 uint32_t k_split;
946 uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
947 uint32_t padded_N;
948};
949
950#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
951#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
952#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
953#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
954
955struct vk_mat_vec_push_constants {
956 uint32_t ncols;
957 uint32_t stride_a;
958 uint32_t stride_b;
959 uint32_t stride_d;
960 uint32_t batch_stride_a;
961 uint32_t batch_stride_b;
962 uint32_t batch_stride_d;
963 uint32_t fusion_flags;
964 uint32_t ne02;
965 uint32_t ne12;
966 uint32_t broadcast2;
967 uint32_t broadcast3;
968};
969
970struct vk_mat_vec_p021_push_constants {
971 uint32_t ncols_x;
972 uint32_t nrows_x;
973 uint32_t nchannels_x;
974 uint32_t nchannels_y;
975 uint32_t b_offset;
976 uint32_t d_offset;
977 uint32_t fusion_flags;
978};
979
980struct vk_mat_vec_nc_push_constants {
981 uint32_t ncols_x;
982 uint32_t nrows_x;
983 uint32_t row_stride_x;
984 uint32_t channel_stride_x;
985 uint32_t channel_stride_y;
986 uint32_t channel_x_divisor;
987 uint32_t ne12;
988 uint32_t b_offset;
989 uint32_t d_offset;
990 uint32_t nb03;
991 uint32_t nb13;
992 uint32_t nb23;
993 uint32_t fusion_flags;
994};
995
996struct vk_mat_mat_id_push_constants {
997 uint32_t M; uint32_t N; uint32_t K;
998 uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
999 uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
1000 uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
1001 uint32_t padded_N;
1002};
1003struct vk_mat_vec_id_push_constants {
1004 uint32_t ncols;
1005 uint32_t stride_a;
1006 uint32_t stride_b;
1007 uint32_t stride_d;
1008 uint32_t batch_stride_a;
1009 uint32_t batch_stride_b;
1010 uint32_t batch_stride_d;
1011 uint32_t fusion_flags;
1012 uint32_t nei0;
1013 uint32_t ne11;
1014 uint32_t expert_i1;
1015 uint32_t nbi1;
1016};
1017
1018struct vk_flash_attn_push_constants {
1019 uint32_t N;
1020 uint32_t KV;
1021
1022 uint32_t ne1;
1023 uint32_t ne2;
1024 uint32_t ne3;
1025
1026 uint32_t neq2;
1027 uint32_t neq3;
1028 uint32_t nek2;
1029 uint32_t nek3;
1030 uint32_t nev2;
1031 uint32_t nev3;
1032 uint32_t nem1;
1033 uint32_t nem2;
1034 uint32_t nem3;
1035
1036 uint32_t nb01;
1037 uint32_t nb02;
1038 uint32_t nb03;
1039 uint32_t nb11;
1040 uint32_t nb12;
1041 uint32_t nb13;
1042 uint32_t nb21;
1043 uint32_t nb22;
1044 uint32_t nb23;
1045
1046 float scale;
1047 float max_bias;
1048 float logit_softcap;
1049
1050 uint32_t mask_n_head_log2;
1051 float m0;
1052 float m1;
1053
1054 uint32_t gqa_ratio;
1055 uint32_t split_kv;
1056 uint32_t k_num;
1057};
1058static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
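// The Vulkan spec guarantees a maxPushConstantsSize of at least 128 bytes, so staying within 128 keeps this usable on any conforming device.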
1059
1060struct vk_op_push_constants {
1061 uint32_t KX;
1062 uint32_t KY;
1063 float param1;
1064 float param2;
1065 float param3;
1066 float param4;
1067};
1068
1069struct vk_op_count_experts_push_constants {
1070 uint32_t ne00;
1071 uint32_t ne01;
1072 uint32_t nb00;
1073 uint32_t nb01;
1074 uint32_t a_offset;
1075};
1076
1077struct vk_op_glu_push_constants {
1078 uint32_t N;
1079 uint32_t ne00;
1080 uint32_t ne20;
1081 uint32_t mode; // 0: default, 1: swapped, 2: split
1082 float alpha; // for swiglu_oai
1083 float limit;
1084};
1085
1086struct vk_op_unary_push_constants {
1087 uint32_t ne;
1088 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1089 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
1090 uint32_t misalign_offsets;
1091 float param1; float param2;
1092 uint32_t ne0_012mp; uint32_t ne0_012L;
1093 uint32_t ne0_01mp; uint32_t ne0_01L;
1094 uint32_t ne0_0mp; uint32_t ne0_0L;
1095 uint32_t ne1_012mp; uint32_t ne1_012L;
1096 uint32_t ne1_01mp; uint32_t ne1_01L;
1097 uint32_t ne1_0mp; uint32_t ne1_0L;
1098};
1099static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
1100
1101static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) {
1102 GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst)));
1103 ne = ne != 0 ? ne : ggml_nelements(dst);
1104 GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
1105
1106 vk_op_unary_push_constants p{};
1107 p.ne = (uint32_t)ne;
1108
1109 size_t src0_tsize = ggml_type_size(src0->type);
1110 p.ne00 = (uint32_t)src0->ne[0];
1111 p.ne01 = (uint32_t)src0->ne[1];
1112 p.ne02 = (uint32_t)src0->ne[2];
1113 p.ne03 = (uint32_t)src0->ne[3];
1114 p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
1115 p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
1116 p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
1117 p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
1118
1119 size_t dst_tsize = ggml_type_size(dst->type);
1120 p.ne10 = (uint32_t)dst->ne[0];
1121 p.ne11 = (uint32_t)dst->ne[1];
1122 p.ne12 = (uint32_t)dst->ne[2];
1123 p.ne13 = (uint32_t)dst->ne[3];
1124 p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
1125 p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
1126 p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
1127 p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
1128
1129 return p; // offsets are initialized later in ggml_vk_op
1130}
1131
1132struct vk_op_pad_push_constants {
1133 uint32_t ne;
1134 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1135 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
1136 uint32_t misalign_offsets;
1137 uint32_t circular;
1138
1139 uint32_t lp0; uint32_t rp0;
1140 uint32_t lp1; uint32_t rp1;
1141 uint32_t lp2; uint32_t rp2;
1142 uint32_t lp3; uint32_t rp3;
1143};
1144
1145static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
1146 int64_t ne = ggml_nelements(dst);
1147 GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
1148
1149 vk_op_pad_push_constants p{};
1150 p.ne = (uint32_t)ne;
1151
1152 size_t src0_tsize = ggml_type_size(src0->type);
1153 p.ne00 = (uint32_t)src0->ne[0];
1154 p.ne01 = (uint32_t)src0->ne[1];
1155 p.ne02 = (uint32_t)src0->ne[2];
1156 p.ne03 = (uint32_t)src0->ne[3];
1157 p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
1158 p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
1159 p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
1160 p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
1161
1162 size_t dst_tsize = ggml_type_size(dst->type);
1163 p.ne10 = (uint32_t)dst->ne[0];
1164 p.ne11 = (uint32_t)dst->ne[1];
1165 p.ne12 = (uint32_t)dst->ne[2];
1166 p.ne13 = (uint32_t)dst->ne[3];
1167 p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
1168 p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
1169 p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
1170 p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
1171
1172 p.lp0 = dst->op_params[0];
1173 p.rp0 = dst->op_params[1];
1174 p.lp1 = dst->op_params[2];
1175 p.rp1 = dst->op_params[3];
1176 p.lp2 = dst->op_params[4];
1177 p.rp2 = dst->op_params[5];
1178 p.lp3 = dst->op_params[6];
1179 p.rp3 = dst->op_params[7];
1180 p.circular = dst->op_params[8];
1181
1182 return p; // fastdiv values and offsets are initialized later in ggml_vk_op
1183}
1184
1185// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
1186// Precompute mp (m' in the paper) and L such that division
1187// can be computed using a multiply (high 32b of 64b result)
1188// and a shift:
1189//
1190// n/d = (mulhi(n, mp) + n) >> L;
1191static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
1192{
1193 // compute L = ceil(log2(d));
1194 L = 0;
1195 while (L < 32 && (uint32_t{1} << L) < d) {
1196 L++;
1197 }
1198
1199 mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
1200}
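// Worked example: for d = 3 the loop yields L = 2 and
// mp = (2^32 * (2^2 - 3)) / 3 + 1 = 1431655766; then e.g. n = 7 gives
// mulhi(7, mp) = 2 and (2 + 7) >> 2 == 2 == 7 / 3.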
1201
1202template <typename T> void init_pushconst_fastdiv(T &p) {
1203 GGML_UNUSED(p);
1204 static_assert(!std::is_const<T>::value, "unexpected type");
1205}
1206
1207template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
1208 // Compute magic values to divide by these six numbers.
1209 init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
1210 init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
1211 init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
1212 init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
1213 init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
1214 init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
1215}
1216
1217struct vk_op_binary_push_constants {
1218 uint32_t ne;
1219 uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1220 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
1221 uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
1222 uint32_t misalign_offsets;
1223 float param1; float param2; int32_t param3;
1224};
1225
1226struct vk_op_multi_add_push_constants {
1227 // shape for dst
1228 uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23;
1229
1230 // strides for srcs+dst
1231 uint32_t nb[MAX_PARAMETER_COUNT][4];
1232
1233 uint32_t rms_partials;
1234};
1235// update multi_add.comp if this changes
1236static_assert(MAX_PARAMETER_COUNT == 12);
1237static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
1238
1239struct vk_op_topk_moe_push_constants {
1240 uint32_t n_rows;
1241 uint32_t n_experts_push;
1242 uint32_t n_expert_used;
1243 float clamp_min;
1244 float clamp_max;
1245 uint32_t gating_func;
1246 uint32_t has_bias;
1247 uint32_t with_norm;
1248 float output_scale;
1249 float output_bias;
1250};
1251
1252struct vk_op_add_id_push_constants {
1253 uint32_t ne0;
1254 uint32_t ne1;
1255 uint32_t s01;
1256 uint32_t s02;
1257 uint32_t s11;
1258 uint32_t s21;
1259};
1260
1261struct vk_op_diag_mask_push_constants {
1262 uint32_t ncols;
1263 uint32_t rows_per_channel;
1264 int32_t n_past;
1265};
1266
1267struct vk_op_rope_push_constants {
1268 uint32_t rope_mode;
1269 uint32_t nrows;
1270 uint32_t n_dims;
1271 float freq_scale;
1272 float freq_base;
1273 float ext_factor;
1274 float attn_factor;
1275 float corr_dims[2];
1276 float theta_scale;
1277 uint32_t has_ff;
1278 int32_t sections[4];
1279 uint32_t is_imrope;
1280 uint32_t is_back;
1281 uint32_t set_rows_stride;
1282 uint32_t ne00;
1283 uint32_t ne01;
1284 uint32_t ne02;
1285 uint32_t nb01;
1286 uint32_t nb02;
1287 uint32_t nb03;
1288 uint32_t nb11;
1289 uint32_t nb12;
1290 uint32_t nb13;
1291};
1292static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
1293
1294// For fused rms_norm+mul+rope(+view+set_rows)
1295struct vk_op_rms_norm_mul_rope_push_constants {
1296 vk_op_binary_push_constants bin;
1297 vk_op_rope_push_constants rope;
1298};
1299
1300struct vk_op_soft_max_push_constants {
1301 uint32_t KX;
1302 uint32_t KY;
1303 uint32_t ne00;
1304 uint32_t ne01;
1305 uint32_t ne02;
1306 uint32_t ne12;
1307 uint32_t ne13;
1308 uint32_t nb11;
1309 uint32_t nb12;
1310 uint32_t nb13;
1311 float scale;
1312 float max_bias;
1313 float m0;
1314 float m1;
1315 uint32_t n_head_log2;
1316 uint32_t nrows_x;
1317 uint32_t has_sinks;
1318};
1319
1320struct vk_op_argsort_push_constants {
1321 uint32_t ncols;
1322 uint32_t ncols_padded;
1323 uint32_t ncols_padded_log2;
1324 uint32_t nrows;
1325 uint32_t order;
1326 uint32_t outer_start;
1327 uint32_t outer_end;
1328 uint32_t inner_start;
1329 uint32_t inner_end;
1330};
1331
1332struct vk_op_topk_push_constants {
1333 uint32_t orig_ncols;
1334 uint32_t ncols_input;
1335 uint32_t ncols_output;
1336 uint32_t k;
1337 uint32_t nrows;
1338 uint32_t first_pass;
1339 uint32_t last_pass;
1340};
1341
1342struct vk_op_im2col_push_constants {
1343 uint64_t dst_addr;
1344 uint32_t batch_offset; uint32_t offset_delta;
1345 uint32_t IC;
1346 uint32_t IW; uint32_t IH;
1347 uint32_t OW; uint32_t OH;
1348 uint32_t KW; uint32_t KH;
1349 uint32_t pelements;
1350 uint32_t CHW;
1351 int32_t s0; int32_t s1;
1352 int32_t p0; int32_t p1;
1353 int32_t d0; int32_t d1;
1354 uint32_t batch_IC;
1355};
1356
1357struct vk_op_im2col_3d_push_constants {
1358 uint64_t dst_addr;
1359 uint32_t nb10;
1360 uint32_t nb11;
1361 uint32_t nb12;
1362 uint32_t nb13;
1363 uint32_t s0;
1364 uint32_t s1;
1365 uint32_t s2;
1366 uint32_t p0;
1367 uint32_t p1;
1368 uint32_t p2;
1369 uint32_t d0;
1370 uint32_t d1;
1371 uint32_t d2;
1372 uint32_t IW;
1373 uint32_t IH;
1374 uint32_t ID;
1375 uint32_t IC;
1376 uint32_t KW;
1377 uint32_t OH;
1378 uint32_t KD_KH_KW;
1379 uint32_t KH_KW;
1380 uint32_t IC_KD_KH_KW;
1381 uint32_t N_OD_OH;
1382 uint32_t OD_OH;
1383 uint32_t OD_OH_OW_IC_KD_KH_KW;
1384 uint32_t OH_OW_IC_KD_KH_KW;
1385 uint32_t OW_IC_KD_KH_KW;
1386 uint32_t misalign_offsets;
1387};
1388
1389struct vk_op_timestep_embedding_push_constants {
1390 uint32_t nb1;
1391 uint32_t dim;
1392 uint32_t max_period;
1393};
1394
1395struct vk_op_conv_transpose_1d_push_constants {
1396 uint32_t Cout;
1397 uint32_t Cin;
1398 uint32_t K;
1399 uint32_t L;
1400 uint32_t KL;
1401
1402 uint32_t nb01;
1403 uint32_t nb02;
1404 uint32_t nb11;
1405 uint32_t nb1;
1406
1407 int32_t s0;
1408};
1409
1410struct vk_op_pool2d_push_constants {
1411 uint32_t IW; uint32_t IH;
1412 uint32_t OW; uint32_t OH;
1413 uint32_t OC;
1414 uint32_t pelements;
1415 uint32_t op;
1416 int32_t k0; int32_t k1;
1417 int32_t s0; int32_t s1;
1418 int32_t p0; int32_t p1;
1419};
1420
1421struct vk_op_rwkv_wkv6_push_constants {
1422 uint32_t B;
1423 uint32_t T;
1424 uint32_t C;
1425 uint32_t H;
1426};
1427
1428struct vk_op_rwkv_wkv7_push_constants {
1429 uint32_t B;
1430 uint32_t T;
1431 uint32_t C;
1432 uint32_t H;
1433};
1434struct vk_op_ssm_scan_push_constants {
1435 uint32_t nb02, nb03, nb12, nb13;
1436 uint32_t nb21, nb22, nb31;
1437 uint32_t nb42, nb43, nb52, nb53;
1438 uint32_t s_off;
1439 uint32_t n_head, d_head, n_group, n_tok;
1440};
1441struct vk_op_ssm_conv_push_constants {
1442 uint32_t nb01, nb02;
1443 uint32_t nb11;
1444 uint32_t dst_nb0, dst_nb1, dst_nb2;
1445 uint32_t nc, ncs, nr, n_t, n_s;
1446};
1447
1448struct vk_op_conv2d_push_constants {
1449 uint32_t Cout;
1450 uint32_t Cin;
1451 uint32_t N;
1452
1453 uint32_t W;
1454 uint32_t H;
1455 uint32_t OW;
1456 uint32_t OH;
1457
1458 uint32_t nb01;
1459 uint32_t nb02;
1460 uint32_t nb03;
1461
1462 uint32_t nb11;
1463 uint32_t nb12;
1464 uint32_t nb13;
1465
1466 uint32_t nb1;
1467 uint32_t nb2;
1468 uint32_t nb3;
1469
1470 // init_fastdiv_values constants for dividing by OW, OW*OH
1471 uint32_t OWmp; uint32_t OWL;
1472 uint32_t OWOHmp; uint32_t OWOHL;
1473};
1474
1475template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) {
1476 // Compute magic values to divide by OW, OW*OH
1477 init_fastdiv_values(p.OW, p.OWmp, p.OWL);
1478 init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL);
1479}
1480
1481struct vk_op_conv2d_dw_push_constants {
1482 uint32_t ne;
1483 uint32_t batches;
1484 uint32_t channels;
1485 uint32_t dst_w;
1486 uint32_t dst_h;
1487 uint32_t src_w;
1488 uint32_t src_h;
1489 uint32_t knl_w;
1490 uint32_t knl_h;
1491 int32_t stride_x;
1492 int32_t stride_y;
1493 int32_t pad_x;
1494 int32_t pad_y;
1495 int32_t dilation_x;
1496 int32_t dilation_y;
1497};
1498
1499struct vk_op_upscale_push_constants {
1500 uint32_t ne; uint32_t a_offset; uint32_t d_offset;
1501 uint32_t ne00; uint32_t ne01;
1502 uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
1503 uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
1504 float sf0; float sf1; float sf2; float sf3;
1505 float pixel_offset;
1506};
1507
1508struct vk_op_sum_rows_push_constants
1509{
1510 uint32_t n_cols;
1511 uint32_t ne01, ne02;
1512 uint32_t nb01, nb02, nb03;
1513 uint32_t nb11, nb12, nb13;
1514 float weight;
1515 uint32_t misalign_offsets;
1516 uint32_t ne0_12mp, ne0_12L;
1517 uint32_t ne0_1mp, ne0_1L;
1518};
1519
1520static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) {
1521 uint32_t type_size = (uint32_t)ggml_type_size(src->type);
1522 vk_op_sum_rows_push_constants p = {};
1523 p.n_cols = (uint32_t)n_cols;
1524 p.ne01 = (uint32_t)src->ne[1];
1525 p.ne02 = (uint32_t)src->ne[2];
1526 p.nb01 = (uint32_t)src->nb[1] / type_size;
1527 p.nb02 = (uint32_t)src->nb[2] / type_size;
1528 p.nb03 = (uint32_t)src->nb[3] / type_size;
1529 p.nb11 = (uint32_t)dst->nb[1] / type_size;
1530 p.nb12 = (uint32_t)dst->nb[2] / type_size;
1531 p.nb13 = (uint32_t)dst->nb[3] / type_size;
1532 p.weight = 1.0f;
1533 return p;
1534}
1535
1536template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) {
1537 init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L);
1538 init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L);
1539}
1540
1541struct vk_quantize_q8_1_push_constants {
1542 uint32_t ne;
1543 uint32_t num_blocks;
1544};
1545
1546struct vk_op_flash_attn_split_k_reduce_push_constants {
1547 uint32_t D;
1548 uint32_t ne1;
1549 uint32_t ne2;
1550 uint32_t ne3;
1551 uint32_t k_num;
1552 uint32_t sinks;
1553};
1554
1555struct vk_op_flash_attn_mask_opt_push_constants {
1556 uint32_t nem0;
1557 uint32_t nem1;
1558 uint32_t nem2;
1559 uint32_t nbm1;
1560 uint32_t nbm2;
1561 uint32_t nbm3;
1562 uint32_t nbd1;
1563 uint32_t nbd2;
1564 uint32_t nbd3;
1565};
1566
1567// Allow pre-recording command buffers
1568struct vk_staging_memcpy {
1569 vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
1570
1571 void * dst;
1572 const void * src;
1573 size_t n;
1574};
1575
1576struct vk_staging_memset {
1577 vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
1578
1579 void * dst;
1580 uint32_t val;
1581 size_t n;
1582};
1583
1584struct vk_context_struct {
1585 vk_submission * s;
1586 std::vector<vk_sequence> seqs;
1587
1588 int exit_tensor_idx;
1589
1590 std::vector<vk_staging_memcpy> in_memcpys;
1591 std::vector<vk_staging_memcpy> out_memcpys;
1592 std::vector<vk_staging_memset> memsets;
1593
1594 vk_command_pool * p {};
1595};
1596typedef std::shared_ptr<vk_context_struct> vk_context;
1597typedef std::weak_ptr<vk_context_struct> vk_context_ref;
1598
1599struct ggml_vk_garbage_collector {
1600 std::vector<vk_semaphore> tl_semaphores;
1601 std::vector<vk_semaphore> semaphores;
1602 std::vector<vk::Event> events;
1603 std::vector<vk_context> contexts;
1604};
1605
1606static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
1607static void ggml_vk_load_shaders(vk_device& device);
1608static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
1609
1610static bool vk_memory_logger_enabled = false;
1611
1612#define VK_LOG_MEMORY(msg) if (vk_memory_logger_enabled) { std::cerr << "ggml_vulkan memory: " << msg << std::endl; }
1613
1614static std::string format_size(size_t size) {
1615 const size_t kib = 1024;
1616 const size_t mib = kib * 1024;
1617 const size_t gib = mib * 1024;
1618
1619 std::ostringstream oss;
1620 oss << std::fixed << std::setprecision(2);
1621
1622 if (size >= gib) {
1623 oss << static_cast<double>(size) / gib << " GiB";
1624 } else if (size >= mib) {
1625 oss << static_cast<double>(size) / mib << " MiB";
1626 } else if (size >= kib) {
1627 oss << static_cast<double>(size) / kib << " KiB";
1628 } else {
1629 oss << size << " B";
1630 }
1631
1632 return oss.str();
1633}
1634
1635class vk_memory_logger {
1636public:
1637 vk_memory_logger(): total_device(0), total_host(0) {}
1638 void log_allocation(vk_buffer_ref buf_ref, size_t size);
1639 void log_deallocation(vk_buffer_ref buf_ref);
1640
1641private:
1642 std::map<vk::Buffer, size_t> allocations; // Track allocations
1643 size_t total_device;
1644 size_t total_host;
1645 static std::mutex log_mutex;
1646};
1647
1648std::mutex vk_memory_logger::log_mutex;
1649
1650static bool vk_perf_logger_enabled = false;
1651static bool vk_perf_logger_concurrent = false;
1652static bool vk_enable_sync_logger = false;
1653// number of calls between perf logger prints
1654static uint32_t vk_perf_logger_frequency = 1;
1655
1656class vk_perf_logger {
1657 public:
1658 void print_timings(bool force = false) {
1659 if (timings.empty()) {
1660 return;
1661 }
1662 print_count++;
1663 if ((print_count % vk_perf_logger_frequency) != 0 && !force) {
1664 return;
1665 }
1666 print_count = 0;
1667 uint64_t total_all_op_times = 0;
1668 std::cerr << "----------------\nVulkan Timings:" << std::endl;
1669 for (const auto & t : timings) {
1670 uint64_t total_op_times = 0;
1671 for (const auto & time : t.second) {
1672 total_op_times += time;
1673 }
1674 std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0)
1675 << " us = " << (total_op_times / 1000.0) << " us";
1676
1677            // If we have as many flops entries as timing entries for the op, compute and log the FLOP rate as well.
1678 auto it = flops.find(t.first);
1679 if (it != flops.end() && (it->second).size() == t.second.size()) {
1680 uint64_t total_op_flops = 0;
1681 for (const auto & elem : it->second) {
1682 total_op_flops += elem;
1683 }
1684 std::cerr << " ("
1685 << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) /
1686 (double(total_op_times) / (1000.0 * 1000.0 * 1000.0))
                          << " GFLOP/s)";
1688 }
1689
1690 total_all_op_times += total_op_times;
1691
1692 std::cerr << std::endl;
1693 }
1694
1695 if (timings.size() > 0) {
1696 std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl;
1697 }
1698
1699 timings.clear();
1700 flops.clear();
1701 }
1702
1703 std::string get_node_fusion_name(const ggml_tensor * node, const char *fusion_name, uint64_t *n_flops) {
1704 *n_flops = 0;
1705 std::string fusion_str;
1706 if (fusion_name) {
1707 fusion_str = fusion_name + std::string(" ");
1708 }
1709 if (node->op == GGML_OP_UNARY) {
1710 return fusion_str + ggml_unary_op_name(ggml_get_unary_op(node));
1711 }
1712 if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
1713 const uint64_t m = node->ne[0];
1714 const uint64_t n = node->ne[1];
1715 const uint64_t k = node->src[1]->ne[0];
1716 const uint64_t batch = node->ne[2] * node->ne[3];
1717 std::string name = ggml_op_name(node->op);
1718 if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) ||
1719 (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) {
1720 name += "_VEC";
1721 }
1722 name += " ";
1723 name += ggml_type_name(node->src[0]->type);
1724 name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k);
1725 if (node->op == GGML_OP_MUL_MAT_ID) {
1726 name += " n_expert=" + std::to_string(node->src[0]->ne[2]);
1727 }
1728 if (batch > 1) {
1729 name += " batch=" + std::to_string(batch);
1730 }
1731 name = fusion_str + name;
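            // Each of the m*n output elements needs k multiplies and k-1 adds, per batch.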
1732 *n_flops = m * n * (k + (k - 1)) * batch;
1733 return name;
1734 }
1735 if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) {
1736 std::string name = ggml_op_name(node->op);
1737 ggml_tensor * knl = node->src[0];
1738 uint64_t OW = node->ne[0];
1739 uint64_t OH = node->ne[1];
1740 uint64_t N = node->ne[3];
1741 uint64_t Cout = node->ne[2];
1742 uint64_t KW = knl->ne[0];
1743 uint64_t KH = knl->ne[1];
1744 uint64_t Cin = node->src[1]->ne[2];
1745 // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ
1746 uint64_t size_M = Cout;
1747 uint64_t size_K = Cin * KW * KH;
1748 uint64_t size_N = N * OW * OH;
1749 *n_flops = size_M * size_N * (size_K + (size_K - 1));
1750 name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) +
1751 ", N=N*OW*OH=" + std::to_string(size_N);
1752 name = fusion_str + name;
1753 return name;
1754 }
1755 if (node->op == GGML_OP_RMS_NORM) {
1756 std::string name = ggml_op_name(node->op);
1757 name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")";
1758 name = fusion_str + name;
1759 return name;
1760 }
1761 if (node->op == GGML_OP_FLASH_ATTN_EXT) {
1762 const ggml_tensor * dst = node;
1763 const ggml_tensor * q = node->src[0];
1764 const ggml_tensor * k = node->src[1];
1765 const ggml_tensor * v = node->src[2];
1766 const ggml_tensor * m = node->src[3];
1767 std::stringstream name;
1768 name << fusion_str;
1769 name << ggml_op_name(node->op) <<
1770 " dst(" << dst->ne[0] << "," << dst->ne[1] << "," << dst->ne[2] << "," << dst->ne[3] << "), " <<
1771 " q(" << q->ne[0] << "," << q->ne[1] << "," << q->ne[2] << "," << q->ne[3] << "), " <<
1772 " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
1773 " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
1774 " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
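            // Two matmuls per head and batch: Q*K^T (HSK) and P*V (HSV), counting a multiply-add as 2 flops.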
1775 *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
1776 return name.str();
1777 }
1778 if (node->op == GGML_OP_TOP_K) {
1779 std::stringstream name;
1780 name << fusion_str;
1781 name << ggml_op_name(node->op) <<
1782 " K=" << node->ne[0] <<
1783 " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
1784 return name.str();
1785 }
1786 return fusion_str + ggml_op_name(node->op);
1787 }
1788
1789 void log_timing(const ggml_tensor * node, const char *fusion_name, uint64_t time) {
1790 uint64_t n_flops;
1791 std::string name = get_node_fusion_name(node, fusion_name, &n_flops);
1792 if (n_flops) {
1793 flops[name].push_back(n_flops);
1794 }
1795 timings[name].push_back(time);
1796 }
1797
1798 void log_timing(const std::vector<ggml_tensor *> &nodes, const std::vector<const char *> &names, uint64_t time) {
1799 uint64_t total_flops = 0;
1800 std::string name;
1801 for (size_t n = 0; n < nodes.size(); ++n) {
1802 uint64_t n_flops = 0;
1803 name += get_node_fusion_name(nodes[n], names[n], &n_flops);
1804 total_flops += n_flops;
1805
1806 if (n != nodes.size() - 1) {
1807 name += ", ";
1808 }
1809 }
1810 if (total_flops) {
1811 flops[name].push_back(total_flops);
1812 }
1813 timings[name].push_back(time);
1814 }
1815
1816 private:
1817 std::map<std::string, std::vector<uint64_t>> timings;
1818 std::map<std::string, std::vector<uint64_t>> flops;
1819 uint32_t print_count {};
1820};
1821
1822struct ggml_backend_vk_context {
1823 std::string name;
1824
1825 vk_device device;
1826
1827 size_t semaphore_idx, event_idx;
1828 ggml_vk_garbage_collector gc;
1829 size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset;
1830 vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials, sync_staging;
1831 vk::Fence fence, almost_ready_fence;
1832 bool submit_pending {};
1833 bool almost_ready_fence_pending {};
1834 // Set before op_add and unset after op_rms_norm to indicate that the add should
1835 // write partial sums to accumulate the square of the vector components
1836 bool do_add_rms_partials_offset_calculation;
1837 bool do_add_rms_partials;
1838
1839 uint64_t last_total_mul_mat_bytes {};
1840
1841 // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
1842 vk_pipeline_struct * prealloc_y_last_pipeline_used {};
1843 const ggml_tensor * prealloc_y_last_tensor_used {};
1844
1845 // Track which nodes have been used since the last sync, and whether they were written to
1846 std::vector<const ggml_tensor *> unsynced_nodes_written;
1847 std::vector<const ggml_tensor *> unsynced_nodes_read;
1848 // Track which prealloc buffers have pending reads that need to be synchronized.
1849 // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
1850 // and set to true after the buffer contents are consumed.
1851 bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
1852
1853 vk_context_ref compute_ctx;
1854
1855 std::vector<vk_context_ref> tensor_ctxs;
1856
1857 std::vector<vk::DescriptorPool> descriptor_pools;
1858 std::vector<vk::DescriptorSet> descriptor_sets;
1859 uint32_t descriptor_set_idx {};
1860 uint32_t pipeline_descriptor_set_requirements {};
1861
1862 vk_command_pool compute_cmd_pool;
1863
1864 // number of additional consecutive nodes that are being fused with the
1865 // node currently being processed
1866 int num_additional_fused_ops {};
1867 // Bitmask of which fused ops need to write an intermediate value to memory.
1868 // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
1869 // If there's no fusion, bit 0 is still set.
1870 int fused_ops_write_mask {};
1871 topk_moe_mode fused_topk_moe_mode {};
1872 bool fused_topk_moe_scale {};
1873
1874 // for GGML_VK_PERF_LOGGER
1875 std::unique_ptr<vk_perf_logger> perf_logger;
1876 vk::QueryPool query_pool;
1877 std::vector<const char *> query_fusion_names;
1878 std::vector<int> query_fusion_node_count;
1879 std::vector<ggml_tensor *> query_nodes;
1880 std::vector<int> query_node_idx;
1881 int32_t num_queries {};
1882 int32_t query_idx {};
1883};
1884
1885static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
1886
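// Tensor data pointers for Vulkan buffers are fake pointers relative to vk_ptr_base;
// recover the byte offset into the backing buffer here.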
1887static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
1888 if (tensor->view_src) {
1889 return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
1890 }
1891 return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
1892}
1893
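// Bytes by which the tensor's buffer offset misses minStorageBufferOffsetAlignment.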
1894static uint32_t get_misalign_bytes(const ggml_backend_vk_context * ctx, const ggml_tensor * t)
1895{
    return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));
1897}
1898
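// Default case: misaligned tensor offsets are not supported, so assert that everything
// is already aligned. The specializations below instead patch element offsets into the
// push constants for ops that can handle misaligned inputs.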
1899template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
1900 GGML_UNUSED(p);
1901 GGML_UNUSED(src0);
1902 GGML_UNUSED(src1);
1903 GGML_UNUSED(src2);
1904 GGML_UNUSED(src3);
1905 GGML_UNUSED(dst);
1906 static_assert(!std::is_const<T>::value, "unexpected type");
1907 GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
1908 GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
1909 GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
1910 GGML_ASSERT(!src3 || get_misalign_bytes(ctx, src3) == 0);
1911 GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
1912}
1913
1914template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_p021_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
1915 const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
1916 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
1917
1918 p.b_offset = b_offset;
1919 p.d_offset = d_offset;
1920
1921 GGML_UNUSED(src0);
1922 GGML_UNUSED(src2);
1923 GGML_UNUSED(src3);
1924}
1925
1926template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_mat_vec_nc_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
1927 const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
1928 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
1929
1930 p.b_offset = b_offset;
1931 p.d_offset = d_offset;
1932
1933 GGML_UNUSED(src0);
1934 GGML_UNUSED(src2);
1935 GGML_UNUSED(src3);
1936}
1937
1938struct ggml_backend_vk_buffer_context {
1939 vk_device_ref device;
1940 vk_buffer dev_buffer;
1941 std::string name;
1942
1943 ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
1944 device(device),
1945 dev_buffer(dev_buffer),
1946 name(name) {
1947 }
1948
1949 ~ggml_backend_vk_buffer_context() {
1950 ggml_vk_destroy_buffer(dev_buffer);
1951 }
1952};
1953
1954void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
1955 if (!vk_memory_logger_enabled) {
1956 return;
1957 }
1958 std::lock_guard<std::mutex> guard(log_mutex);
1959 vk_buffer buf = buf_ref.lock();
1960 const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
1961 const std::string type = device ? "device" : "host";
1962 allocations[buf->buffer] = size;
1963 total_device += device ? size : 0;
1964 total_host += device ? 0 : size;
1965 VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
1966}
1967
1968void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
1969 if (buf_ref.expired() || buf_ref.lock()->size == 0 || !vk_memory_logger_enabled) {
1970 return;
1971 }
1972
1973 std::lock_guard<std::mutex> guard(log_mutex);
1974 vk_buffer buf = buf_ref.lock();
1975 const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
1976 std::string type = device ? "device" : "host";
    auto it = allocations.find(buf->buffer);
    if (it != allocations.end()) {
        // Only adjust the running totals for allocations that were actually tracked.
        total_device -= device ? it->second : 0;
        total_host -= device ? 0 : it->second;
        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
        allocations.erase(it);
    } else {
        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
    }
1986}
1987
1988struct vk_instance_t {
1989 vk::Instance instance;
1990
1991 bool debug_utils_support = false; // VK_EXT_debug_utils enabled
1992 PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
1993 PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
1994 PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
1995 PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
1996 PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
1997 PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
1998
1999 std::vector<size_t> device_indices;
2000 std::vector<bool> device_supports_membudget;
2001 vk_device devices[GGML_VK_MAX_DEVICES];
2002};
2003
2004static bool vk_instance_initialized = false;
2005static vk_instance_t vk_instance;
2006
2007#ifdef GGML_VULKAN_CHECK_RESULTS
2008static size_t vk_skip_checks;
2009static size_t vk_output_tensor;
2010
2011static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
2012static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
2013static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
2014#endif
2015
2016typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
2017
2018static void ggml_backend_vk_free(ggml_backend_t backend);
2019
2020static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
2021 const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
2022 VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
2023 return range;
2024}
2025
2026// Wait for ctx->fence to be signaled.
2027static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
2028 // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
2029 // during this wait.
2030 if (ctx->almost_ready_fence_pending) {
2031 VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
2032 ctx->device->device.resetFences({ ctx->almost_ready_fence });
2033 ctx->almost_ready_fence_pending = false;
2034 }
2035
2036 // Spin (w/pause) waiting for the graph to finish executing.
2037 vk::Result result;
2038 while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
2039 if (result != vk::Result::eNotReady) {
2040 fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
2041 exit(1);
2042 }
2043 for (uint32_t i = 0; i < 100; ++i) {
2044 YIELD();
2045 YIELD();
2046 YIELD();
2047 YIELD();
2048 YIELD();
2049 YIELD();
2050 YIELD();
2051 YIELD();
2052 YIELD();
2053 YIELD();
2054 }
2055 }
2056 ctx->device->device.resetFences({ ctx->fence });
2057}
2058
2059// variables to track number of compiles in progress
2060static uint32_t compile_count = 0;
2061static std::mutex compile_count_mutex;
2062static std::condition_variable compile_count_cond;
2063
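// Compile one compute pipeline: create the shader module and pipeline layout
// (shared descriptor set layout + push constant range), fill in the specialization
// constants, and apply optional subgroup size / robustness / 64-bit indexing settings.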
2064static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
2065 uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
2066 bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
2067 VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count <<
2068 ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " <<
2069 disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
2070 GGML_ASSERT(parameter_count > 0);
2071 GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT);
2072 GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
2073
2074 vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
2075 pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
2076
2077 vk::PushConstantRange pcr(
2078 vk::ShaderStageFlagBits::eCompute,
2079 0,
2080 pipeline->push_constant_size
2081 );
2082
2083 vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr);
2084 pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info);
2085
2086 std::vector<vk::SpecializationMapEntry> specialization_entries(specialization_constants.size());
2087
2088 for (size_t i = 0; i < specialization_constants.size(); i++) {
2089 specialization_entries[i].constantID = i;
2090 specialization_entries[i].offset = i * sizeof(uint32_t);
2091 specialization_entries[i].size = sizeof(uint32_t);
2092 }
2093
2094 vk::SpecializationInfo specialization_info(
2095 specialization_entries.size(),
2096 specialization_entries.data(),
2097 specialization_constants.size() * sizeof(uint32_t),
2098 specialization_constants.data()
2099 );
2100
2101 vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
2102
2103 if (device->subgroup_require_full_support && require_full_subgroups) {
2104 pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
2105 }
2106
2107 vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
2108 pipeline_shader_stage_create_flags,
2109 vk::ShaderStageFlagBits::eCompute,
2110 pipeline->shader_module,
2111 entrypoint.c_str(),
2112 &specialization_info);
2113
2114 vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
2115 pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
2116 if (device->subgroup_size_control && required_subgroup_size > 0) {
2117 GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
2118 pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
2119 }
2120
2121 vk::ComputePipelineCreateInfo compute_pipeline_create_info(
2122 device->pipeline_executable_properties_support ?
2123 vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :
2124 vk::PipelineCreateFlags{},
2125 pipeline_shader_create_info,
2126 pipeline->layout);
2127
2128 vk::PipelineRobustnessCreateInfoEXT rci;
2129
2130 if (device->pipeline_robustness && disable_robustness) {
2131 rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
2132 rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
2133 compute_pipeline_create_info.setPNext(&rci);
2134 }
2135
2136#if defined(VK_EXT_shader_64bit_indexing)
2137 vk::PipelineCreateFlags2CreateInfo pipelineFlags2CreateInfo;
2138 if (pipeline->is_64b_indexing)
2139 {
2140 pipelineFlags2CreateInfo.flags = vk::PipelineCreateFlagBits2::e64BitIndexingEXT;
2141 if (device->pipeline_executable_properties_support) {
2142 pipelineFlags2CreateInfo.flags |= vk::PipelineCreateFlagBits2::eCaptureStatisticsKHR;
2143 }
2144 pipelineFlags2CreateInfo.setPNext(compute_pipeline_create_info.pNext);
2145 compute_pipeline_create_info.setPNext(&pipelineFlags2CreateInfo);
2146 }
2147#endif
2148
2149 try {
2150 pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
2151 } catch (const vk::SystemError& e) {
2152 std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl;
2153 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
        throw;
2155 }
2156 pipeline->compiled = true;
2157
2158 if (vk_instance.debug_utils_support) {
2159 vk::DebugUtilsObjectNameInfoEXT duoni;
2160 duoni.objectType = vk::ObjectType::ePipeline;
2161 duoni.pObjectName = pipeline->name.c_str();
2162 duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast<VkPipeline>(pipeline->pipeline));
2163 vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast<VkDebugUtilsObjectNameInfoEXT &>(duoni));
2164 }
2165
2166 if (device->pipeline_executable_properties_support) {
2167 vk::PipelineExecutableInfoKHR executableInfo;
2168 executableInfo.pipeline = pipeline->pipeline;
2169
2170 auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
2171 for (auto & s : statistics) {
2172 // "Register Count" is reported by NVIDIA drivers.
2173 if (strcmp(s.name, "Register Count") == 0) {
2174 VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
2175 pipeline->register_count = (uint32_t)s.value.u64;
2176 }
2177 }
2178 }
2179
2180 device->all_pipelines.push_back(pipeline);
2181
2182 {
2183 std::lock_guard<std::mutex> guard(compile_count_mutex);
2184 assert(compile_count > 0);
2185 compile_count--;
2186 }
2187 compile_count_cond.notify_all();
2188}
2189
2190static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
    VK_LOG_DEBUG("ggml_vk_destroy_pipeline(" << pipeline->name << ")");
2192 device.destroyPipelineLayout(pipeline->layout);
2193
2194 device.destroyShaderModule(pipeline->shader_module);
2195
2196 device.destroyPipeline(pipeline->pipeline);
2197}
2198
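// Record that this pipeline needs `n` more descriptor sets for the current graph,
// compiling the pipeline on first use and growing the descriptor pools as needed.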
2199static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) {
2200 VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
2201 ctx->pipeline_descriptor_set_requirements += n;
2202 if (!pipeline->compiled) {
2203 pipeline->needed = true;
2204 ggml_vk_load_shaders(ctx->device);
2205 }
2206 ggml_pipeline_allocate_descriptor_sets(ctx);
2207}
2208
2209static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) {
2210
2211 if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) {
2212 // Enough descriptors are available
2213 return;
2214 }
2215
2216 vk_device& device = ctx->device;
2217
2218 // Grow by 50% to avoid frequent allocations
2219 uint32_t needed = std::max(3 * ctx->descriptor_sets.size() / 2, size_t{ctx->pipeline_descriptor_set_requirements});
2220 uint32_t to_alloc = needed - ctx->descriptor_sets.size();
2221 uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE;
2222 uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE;
2223
2224 while (to_alloc > 0) {
2225 const uint32_t alloc_count = std::min(pool_remaining, to_alloc);
2226 to_alloc -= alloc_count;
2227 pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE;
2228
2229 if (pool_idx >= ctx->descriptor_pools.size()) {
2230 vk::DescriptorPoolSize descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
2231 vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
2232 ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
2233 }
2234
2235 std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
2236 for (uint32_t i = 0; i < alloc_count; i++) {
2237 layouts[i] = device->dsl;
2238 }
2239 vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
2240 std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
2241 ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
2242
2243 pool_idx++;
2244 }
2245}
2246
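// Hand out a command buffer from the pool, reusing previously allocated ones when possible.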
2247static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
2248 VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
2249
2250 if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
2251 // Reuse command buffer
2252 return p.cmd_buffers[p.cmd_buffer_idx++];
2253 }
2254
2255 vk::CommandBufferAllocateInfo command_buffer_alloc_info(
2256 p.pool,
2257 vk::CommandBufferLevel::ePrimary,
2258 1);
2259 const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
2260 auto buf = cmd_buffers.front();
2261
2262 p.cmd_buffers.push_back(buf);
2263 p.cmd_buffer_idx++;
2264
2265 return buf;
2266}
2267
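// Submit all recorded sequences in the context to its queue in a single vkQueueSubmit,
// wiring up timeline semaphore waits/signals; optionally signals `fence`.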
2268static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
2269 if (ctx->seqs.empty()) {
2270 if (fence) {
2271 std::lock_guard<std::mutex> guard(queue_mutex);
2272 ctx->p->q->queue.submit({}, fence);
2273 }
2274 return;
2275 }
2276 VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
2277
2278 std::vector<std::vector<uint64_t>> tl_wait_vals;
2279 std::vector<std::vector<uint64_t>> tl_signal_vals;
2280 std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
2281 std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
2282 std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
2283 std::vector<vk::SubmitInfo> submit_infos;
2284 int idx = -1;
2285 std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
2286
2287 size_t reserve = 0;
2288
2289 for (const auto& sequence : ctx->seqs) {
2290 reserve += sequence.size();
2291 }
2292
2293 // Pre-reserve vectors to prevent reallocation, which invalidates pointers
2294 tl_wait_semaphores.reserve(reserve);
2295 tl_wait_vals.reserve(reserve);
2296 tl_signal_semaphores.reserve(reserve);
2297 tl_signal_vals.reserve(reserve);
2298 tl_submit_infos.reserve(reserve);
2299 submit_infos.reserve(reserve);
2300 stage_flags.reserve(reserve);
2301
2302 for (const auto& sequence : ctx->seqs) {
2303 for (const auto& submission : sequence) {
2304 stage_flags.push_back({});
2305 idx++;
2306 tl_wait_vals.push_back({});
2307 tl_wait_semaphores.push_back({});
2308 tl_signal_vals.push_back({});
2309 tl_signal_semaphores.push_back({});
2310 for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
2311 stage_flags[idx].push_back(ctx->p->q->stage_flags);
2312 tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
2313 tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
2314 }
2315 for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
2316 tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
2317 tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
2318 }
2319 tl_submit_infos.push_back({
2320 (uint32_t) submission.wait_semaphores.size(),
2321 tl_wait_vals[idx].data(),
2322 (uint32_t) submission.signal_semaphores.size(),
2323 tl_signal_vals[idx].data(),
2324 });
2325 tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
2326 tl_submit_infos[idx].pNext = nullptr;
2327 vk::SubmitInfo si{
2328 (uint32_t) submission.wait_semaphores.size(),
2329 tl_wait_semaphores[idx].data(),
2330 stage_flags[idx].data(),
2331 1,
2332 &submission.buffer,
2333 (uint32_t) submission.signal_semaphores.size(),
2334 tl_signal_semaphores[idx].data(),
2335 };
2336 si.setPNext(&tl_submit_infos[idx]);
2337 submit_infos.push_back(si);
2338 }
2339 }
2340
2341 std::lock_guard<std::mutex> guard(queue_mutex);
2342 ctx->p->q->queue.submit(submit_infos, fence);
2343
2344 ctx->seqs.clear();
2345}
2346
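// Pick a queue family that supports `required`, preferring families without `avoid`
// and with at least min_num_queues queues, relaxing the constraints step by step.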
2347static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
2348 VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
2349 const uint32_t qfsize = queue_family_props.size();
2350
2351 // Try with avoid preferences first
2352 for (uint32_t i = 0; i < qfsize; i++) {
2353 if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
2354 return i;
2355 }
2356 }
2357
2358 // Fall back to only required
2359 for (size_t i = 0; i < qfsize; i++) {
2360 if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
2361 return i;
2362 }
2363 }
2364
2365 // Fall back to reusing compute queue
2366 for (size_t i = 0; i < qfsize; i++) {
2367 if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
2368 return i;
2369 }
2370 }
2371
    // Fall back to ignoring min_num_queues
2373 for (size_t i = 0; i < qfsize; i++) {
2374 if (queue_family_props[i].queueFlags & required) {
2375 return i;
2376 }
2377 }
2378
2379 // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
2380 // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
2381 if (compute_index >= 0) {
2382 return compute_index;
2383 }
2384
2385 std::cerr << "ggml_vulkan: No suitable queue family index found." << std::endl;
2386
2387 for(auto &q_family : queue_family_props) {
2388 std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl;
2389 }
2390 abort();
2391}
2392
2393static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) {
2394 VK_LOG_DEBUG("ggml_vk_create_queue()");
2395 std::lock_guard<std::recursive_mutex> guard(device->mutex);
2396
2397 q.queue_family_index = queue_family_index;
2398 q.transfer_only = transfer_only;
2399
2400 q.cmd_pool.init(device, &q);
2401
2402 q.queue = device->device.getQueue(queue_family_index, queue_index);
2403
2404 q.stage_flags = stage_flags;
2405}
2406
2407static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) {
2408 vk_context result = std::make_shared<vk_context_struct>();
2409 VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")");
2410 ctx->gc.contexts.emplace_back(result);
2411 result->p = &p;
2412 return result;
2413}
2414
2415static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) {
2416 vk_context result = std::make_shared<vk_context_struct>();
2417 VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")");
2418 result->p = &p;
2419 return result;
2420}
2421
2422static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) {
    VK_LOG_DEBUG("ggml_vk_create_binary_semaphore()");
2424 vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 };
2425 vk::SemaphoreCreateInfo ci{};
2426 ci.setPNext(&tci);
2427 vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
2428 ctx->gc.semaphores.push_back({ semaphore, 0 });
2429 return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1];
2430}
2431
2432static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) {
2433 VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()");
2434 if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) {
2435 vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
2436 vk::SemaphoreCreateInfo ci{};
2437 ci.setPNext(&tci);
2438 vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci);
2439 ctx->gc.tl_semaphores.push_back({ semaphore, 0 });
2440 }
2441 return &ctx->gc.tl_semaphores[ctx->semaphore_idx++];
2442}
2443
2444static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) {
2445 if (ctx->event_idx >= ctx->gc.events.size()) {
2446 ctx->gc.events.push_back(ctx->device->device.createEvent({}));
2447 }
2448 return ctx->gc.events[ctx->event_idx++];
2449}
2450
2451static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) {
2452 VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()");
2453
2454 // Requires command buffers to be done
2455 device->device.resetCommandPool(p.pool);
2456 p.cmd_buffer_idx = 0;
2457}
2458
2459static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
2460 VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()");
2461
2462 // Arbitrary frequency to cleanup/reuse command buffers
2463 static constexpr uint32_t cleanup_frequency = 10;
2464
2465 if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
2466 ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
2467 }
2468 if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
2469 ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
2470 }
2471}
2472
2473static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
2474 std::vector<uint32_t> indices;
2475
2476 for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) {
2477 vk::MemoryType memory_type = mem_props->memoryTypes[i];
2478 if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) &&
2479 (flags & memory_type.propertyFlags) == flags &&
2480 mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) {
2481 indices.push_back(i);
2482 }
2483 }
2484 return indices;
2485}
2486
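// Allocate (or import, when import_ptr is given) a buffer of `size` bytes, trying each
// set of memory property flags in req_flags_list in order until an allocation succeeds.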
2487static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
2488 void *import_ptr = nullptr) {
2489 VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
2490 if (size > device->max_buffer_size) {
2491 throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
2492 }
2493
2494 vk_buffer buf = std::make_shared<vk_buffer_struct>();
2495
2496 if (size == 0) {
2497 buf->size = 0;
2498 return buf;
2499 }
2500
2501 vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2502 vk::MemoryAllocateFlags mem_flags {};
2503 if (device->buffer_device_address) {
2504 usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2505 mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2506 }
2507
2508 vk::BufferCreateInfo buffer_create_info{
2509 vk::BufferCreateFlags(),
2510 size,
2511 usage_flags,
2512 vk::SharingMode::eExclusive,
2513 0,
2514 nullptr,
2515 };
2516
2517 vk::ExternalMemoryBufferCreateInfo external_memory_bci;
2518 if (import_ptr) {
2519 external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
2520 buffer_create_info.setPNext(&external_memory_bci);
2521 }
2522
2523 buf->buffer = device->device.createBuffer(buffer_create_info);
2524
2525 vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
2526
2527 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
2528
2529 const vk::MemoryPriorityAllocateInfoEXT mem_priority_info { 1.0f };
2530
2531 vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2532
2533 if (device->memory_priority) {
2534 mem_flags_info.setPNext(&mem_priority_info);
2535 }
2536
2537 if (import_ptr) {
2538 vk::MemoryHostPointerPropertiesEXT host_pointer_props;
2539 try {
2540 host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
2541 } catch (vk::SystemError& e) {
2542 GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
2543 device->device.destroyBuffer(buf->buffer);
2544 return {};
2545 }
2546 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
2547
2548 uint32_t memory_type_idx;
2549 vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
2550 for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
2551 if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
2552 continue;
2553 }
2554 if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
2555 continue;
2556 }
2557
2558 vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
2559 // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
2560 if ((memory_type.propertyFlags & property_flags) == property_flags) {
2561 property_flags = memory_type.propertyFlags;
2562 break;
2563 }
2564 }
2565 if (memory_type_idx == 32) {
2566 GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
2567 device->device.destroyBuffer(buf->buffer);
2568 return {};
2569 }
2570
2571 buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
2572 try {
2573 vk::ImportMemoryHostPointerInfoEXT import_info;
2574 import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
2575 import_info.pHostPointer = import_ptr;
2576 import_info.setPNext(&mem_flags_info);
2577 buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
        } catch (const vk::SystemError&) {
            // Import failed; leave device_memory null so the check further down reports the failure.
        }
2580 } else {
2581 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
2582 const auto & req_flags = *it;
2583
2584 const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
2585
2586 if (memory_type_indices.empty()) {
2587 continue;
2588 }
2589 buf->memory_property_flags = req_flags;
2590
2591 bool done = false;
2592
2593 for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
2594 try {
2595 buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
2596 done = true;
2597 break;
                } catch (const vk::SystemError&) {
2599 // loop and retry
2600 // during last attempt throw the exception
2601 if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
2602 device->device.destroyBuffer(buf->buffer);
                        throw;
2604 }
2605 }
2606 }
2607
2608 if (done) {
2609 break;
2610 }
2611 }
2612 }
2613
2614 if (!buf->device_memory) {
2615 device->device.destroyBuffer(buf->buffer);
2616 throw vk::OutOfDeviceMemoryError("No suitable memory type found");
2617 }
2618
2619 buf->ptr = nullptr;
2620
2621 if (import_ptr) {
2622 buf->ptr = import_ptr;
2623 } else {
2624 if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
2625 buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
2626 }
2627 }
2628
2629 device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
2630
2631 buf->device = device;
2632 buf->size = size;
2633
2634 if (device->buffer_device_address) {
2635 const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2636 buf->bda_addr = device->device.getBufferAddress(addressInfo);
2637 }
2638
2639 device->memory_logger->log_allocation(buf, size);
2640
2641 return buf;
2642}
2643
2644static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
2645 try {
2646 return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags});
2647 } catch (const vk::SystemError& e) {
2648 std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
2649 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
        throw;
2651 }
2652}
2653
2654static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
2655 vk_buffer buf;
2656 try {
2657 if (device->prefer_host_memory) {
2658 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
2659 vk::MemoryPropertyFlagBits::eDeviceLocal});
2660 } else if (device->uma) {
2661 // Fall back to host memory type
2662 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
2663 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
2664 } else if (device->disable_host_visible_vidmem) {
2665 if (device->allow_sysmem_fallback) {
2666 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
2667 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
2668 } else {
2669 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal});
2670 }
2671 } else {
2672 // use rebar if available, otherwise fallback to device only visible memory
2673 if (device->allow_sysmem_fallback) {
2674 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
2675 vk::MemoryPropertyFlagBits::eDeviceLocal,
2676 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
2677 } else {
2678 buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
2679 vk::MemoryPropertyFlagBits::eDeviceLocal});
2680 }
2681 }
2682 } catch (const vk::SystemError& e) {
2683 std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
2684 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
        throw;
2686 }
2687
2688 return buf;
2689}
2690
2691static void ggml_vk_destroy_buffer(vk_buffer& buf) {
2692 if (buf == nullptr) {
2693 return;
2694 }
2695
2696 if (buf->device != nullptr) {
2697 buf->device->memory_logger->log_deallocation(buf);
2698 }
2699
2700 buf.reset();
2701}
2702
2703static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) {
2704 return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) };
2705}
2706
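// Record a full execution + memory barrier on the context's queue and clear any
// pending prealloc-buffer sync flags, since the barrier covers them.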
2707static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) {
2708 VK_LOG_DEBUG("ggml_vk_sync_buffers()");
2709
2710 const bool transfer_queue = subctx->p->q->transfer_only;
2711
2712 if (ctx) {
2713 ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
2714 }
2715
2716 subctx->s->buffer.pipelineBarrier(
2717 subctx->p->q->stage_flags,
2718 subctx->p->q->stage_flags,
2719 {},
2720 { {
2721 { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) },
2722 { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }
2723 } },
2724 {},
2725 {}
2726 );
2727}
2728
2729static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
2730 VK_LOG_DEBUG("ggml_vk_set_event()");
2731
2732 ctx->s->buffer.setEvent(
2733 event,
2734 ctx->p->q->stage_flags
2735 );
2736}
2737
2738static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events) {
2739 VK_LOG_DEBUG("ggml_vk_wait_events()");
2740 if (events.empty()) {
2741 return;
2742 }
2743
2744 ctx->s->buffer.waitEvents(
2745 events,
2746 ctx->p->q->stage_flags,
2747 ctx->p->q->stage_flags,
2748 {},
2749 {},
2750 {}
2751 );
2752}
2753
2754// number of rows/cols for flash attention shader
2755static constexpr uint32_t flash_attention_num_small_rows = 32;
2756static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
2757
2758static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
2759 if (hsv >= 192) {
2760 return 2;
2761 } else if ((hsv | hsk) & 8 || small_cache) {
2762 return 4;
2763 } else {
2764 return 8;
2765 }
2766}
2767
2768// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
2769// 128 threads split into four subgroups, each subgroup does 1/4
2770// of the Bc dimension.
2771static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
2772static constexpr uint32_t scalar_flash_attention_Bc = 64;
2773static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
2774
2775static uint32_t get_fa_num_small_rows(FaCodePath path) {
2776 if (path == FA_COOPMAT2) {
2777 return flash_attention_num_small_rows;
2778 } else {
2779 return scalar_flash_attention_num_small_rows;
2780 }
2781}
2782
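// Tile shape {rows, cols} used by the flash attention shader for the given code path,
// head sizes and K/V type.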
2783static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
2784 GGML_UNUSED(clamp);
2785
2786 if (path == FA_SCALAR) {
2787 if (small_rows) {
2788 return {scalar_flash_attention_num_small_rows, 64};
2789 } else {
2790 if ((hsv | hsk) & 8) {
2791 // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
2792 // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2793 return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
2794 } else {
2795 return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
2796 }
2797 }
2798 }
2799
2800 if (path == FA_COOPMAT1) {
2801 if (small_rows) {
2802 return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
2803 } else {
2804 return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
2805 }
2806 }
2807
2808 // small rows, large cols
2809 if (small_rows) {
2810 return {get_fa_num_small_rows(FA_COOPMAT2), 32};
2811 }
2812
2813 // small cols to reduce register count
2814 if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
2815 if (hsk >= 512 || hsv >= 512) {
2816 return {32, 32};
2817 } else {
2818 return {64, 32};
2819 }
2820 }
2821 return {64, 64};
2822}
2823
2824static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
2825 return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
2826}
2827
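// Estimate the shared memory the matmul shader needs for this warptile and source type
// (including any dequant lookup table) and check that it fits the device limit.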
2828static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
2829
2830 uint32_t lut_size = 0;
2831 switch (src0_type) {
2832 case GGML_TYPE_IQ1_S:
2833 case GGML_TYPE_IQ1_M:
2834 lut_size = 2*2048 + 4*2048;
2835 break;
2836 case GGML_TYPE_IQ2_XXS:
2837 lut_size = 8*256;
2838 break;
2839 case GGML_TYPE_IQ2_XS:
2840 lut_size = 8*512;
2841 break;
2842 case GGML_TYPE_IQ2_S:
2843 lut_size = 8*1024;
2844 break;
2845 case GGML_TYPE_IQ3_XXS:
2846 lut_size = 4*256;
2847 break;
2848 case GGML_TYPE_IQ3_S:
2849 lut_size = 4*512;
2850 break;
2851 case GGML_TYPE_IQ4_NL:
2852 case GGML_TYPE_IQ4_XS:
2853 case GGML_TYPE_MXFP4:
2854 lut_size = 4*16;
2855 break;
2856 default:
2857 break;
2858 }
2859
2860 // Needs to be kept up to date on shader changes
2861 const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
2862 const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
2863 const uint32_t warps = warptile[0] / warptile[10];
2864
2865 const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
2866 const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
2867 const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
2868 const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
2869
2870 const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
2871 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
2872
2873 VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
2874 "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
2875
2876 return supported;
2877}
2878
2879struct GpuPipelineConfig {
2880 // GPU architecture identifier.
2881 // Example: vk_device_architecture::AMD_GCN
2882 vk_device_architecture arch;
2883
2884 // Mapping of pipeline names to their specific subgroup sizes.
2885 // Example: {"soft_max_f32", 64}
2886 std::unordered_map<std::string, uint32_t> pipelines;
2887
2888 // Default subgroup size for this GPU.
2889 // Defaults to 0 if not explicitly provided.
2890 uint32_t default_subgroup_size = 0;
2891};
2892
2893// Pipeline configuration for RDNA1 GPUs.
2894static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
2895 {"soft_max", 64}, {"im2col", 64},
2896 {"argmax", 64}, {"mul_mat_vec", 64},
2897 {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
2898};
2899
2900// Pipeline configuration for RDNA2 GPUs.
2901static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
2902 {"soft_max", 64}, {"im2col", 64},
2903};
2904
2905static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
2906
2907// Define configurations for different GPUs.
2908static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
2909 {
2910 vk_device_architecture::AMD_RDNA1,
2911 {
2912 rdna1_pipelines,
2913 },
2914 RDNA_DEFAULT_SUBGROUP_SIZE
2915 },
2916 {
2917 vk_device_architecture::AMD_RDNA2,
2918 {
2919 rdna2_pipelines,
2920 },
2921 RDNA_DEFAULT_SUBGROUP_SIZE
2922 },
2923};
2924
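// Required subgroup size override for a pipeline on this architecture: exact name match
// first, then the longest configured name that occurs in the pipeline name, then the
// architecture default (0 means no override).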
2925static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
2926 for (const auto &config : gpu_pipeline_configs) {
2927 if (config.arch == arch) {
2928 auto pipIt = config.pipelines.find(pipeline_name);
2929 if (pipIt != config.pipelines.end()) {
2930 return pipIt->second;
2931 }
2932 std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
2933 std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
2934 [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
2935 for (const auto &entry : sorted_pipelines) {
2936 if (pipeline_name.find(entry.first) != std::string::npos) {
2937 return entry.second;
2938 }
2939 }
2940 return config.default_subgroup_size;
2941 }
2942 }
2943 return 0; // If no matching configuration is found
2944}
2945
2946static void ggml_vk_load_shaders(vk_device& device) {
2947 VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
2948
2949 std::lock_guard<std::recursive_mutex> guard(device->mutex);
2950 // some shaders have a minimum subgroup size
2951 const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
2952 const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
2953 const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
2954
2955 const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
2956 const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u);
2957 const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u);
2958 const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u);
2959
2960 const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) ||
2961 (device->subgroup_size_control && device->subgroup_max_size >= 16);
2962
2963 // mulmat
2964 std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
2965 l_warptile_id, m_warptile_id, s_warptile_id,
2966 l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
2967 l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
2968 l_warptile_mmq_int_k, m_warptile_mmq_int_k, s_warptile_mmq_int_k,
2969 l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
2970 l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid,
2971 l_warptile_mmqid_int, m_warptile_mmqid_int, s_warptile_mmqid_int,
2972 l_warptile_mmqid_int_k, m_warptile_mmqid_int_k, s_warptile_mmqid_int_k;
2973 std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
2974 l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
2975 l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
2976 l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
2977
2978 uint32_t l_align, m_align, s_align;
2979 if (device->coopmat2) {
2980 // spec constants and tile sizes for non-quant matmul/matmul_id
2981 l_warptile = { 256, 128, 256, 64, 1 };
2982 m_warptile = { 256, 128, 128, 64, 0 };
2983 s_warptile = { 128, 64, 64, 64, 0 };
2984 l_wg_denoms = {128, 256, 1 };
2985 m_wg_denoms = {128, 128, 1 };
2986 s_wg_denoms = { 64, 64, 1 };
2987
2988 // spec constants and tile sizes for quant matmul (non-Qi_K)
2989 l_warptile_mmq = { 256, 128, 256, 64, 1 };
2990 m_warptile_mmq = { 256, 128, 128, 64, 1 };
2991 s_warptile_mmq = { 256, 32, 64, 128, 0 };
2992 l_mmq_wg_denoms = { 128, 256, 1 };
2993 m_mmq_wg_denoms = { 128, 128, 1 };
2994 s_mmq_wg_denoms = { 32, 64, 1 };
2995
2996 // spec constants and tile sizes for quant matmul (Qi_K)
2997 l_warptile_mmq_k = { 256, 128, 256, 64, 1 };
2998 m_warptile_mmq_k = { 256, 128, 128, 64, 1 };
2999 s_warptile_mmq_k = { 256, 32, 64, 128, 0 };
3000 l_mmq_wg_denoms_k = { 128, 256, 1 };
3001 m_mmq_wg_denoms_k = { 128, 128, 1 };
3002 s_mmq_wg_denoms_k = { 32, 64, 1 };
3003
3004 // spec constants and tile sizes for quant matmul_id
3005 l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
3006 m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
3007 s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
3008 l_mmqid_wg_denoms = { 128, 128, 1 };
3009 m_mmqid_wg_denoms = { 128, 64, 1 };
3010 s_mmqid_wg_denoms = { 128, 64, 1 };
3011
3012 l_align = 128;
3013 m_align = 64;
3014 s_align = 32;
3015 } else {
3016 // Matrix cores require different warp group sizes
3017 const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
3018 const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
3019 const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
3020 const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
3021 const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
3022 const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
3023 const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
3024 const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
3025 const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
3026
3027 const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
3028
3029 l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
3030 m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3031 s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
3032
3033 l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
3034 m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3035 s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
3036
3037 // Integer MMQ has a smaller shared memory profile, but heavier register use
3038 l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
3039 m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
3040 s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };
3041
3042 // K-quants use even more registers, mitigate by setting WMITER to 1
3043 l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
3044 m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
3045 s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };
3046
3047 l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
3048 m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
3049 s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };
3050
3051 l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
3052 m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
3053 s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };
3054
3055 l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
3056 m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
3057 s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };
3058
3059 l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
3060 m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
3061 s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
3062
        // chip-specific tuning
3064 if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
3065 m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
3066 m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
3067 } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary) {
            // This intentionally uses the tx_m values; it gives a slight performance increase
3069 l_warptile = { 256, 128, 128, 16, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3070 l_warptile_mmq = l_warptile_mmq_int = { 256, 128, 128, 32, subgroup_size_8, 64, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3071 l_warptile_mmq_int_k = { 256, 128, 128, 32, subgroup_size_16, 64, 1, 4, 2, 1, subgroup_size_16 };
3072 } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
3073 // Xe2/Xe3 with coopmat enabled - warptile performance tuning
3074 l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3075 l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
3076 }
3077
        l_mmq_wg_denoms = l_wg_denoms = { 128, 128, 1 };
3079 m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
3080 s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
3081 l_align = 128;
3082 m_align = 64;
3083 s_align = 32;
3084
3085 for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
3086 ggml_type t = (ggml_type)i;
            // Disable medium and large matrix multiplication if not enough shared memory is available.
            // Check the mmq warptiles, as they are the largest configuration.
            // Throw an error if there is not enough shared memory for even the smallest matrix multiplication.
3090 if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) {
3091 std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
3092 throw std::runtime_error("Shared memory size too small for matrix multiplication.");
3093 } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
3094 device->mul_mat_m[i] = false;
3095 device->mul_mat_l[i] = false;
3096 } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
3097 device->mul_mat_l[i] = false;
3098 }
3099
3100 // Disable mul_mat_id if not enough shared memory is available
3101 if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
3102 device->mul_mat_id_s[i] = false;
3103 device->mul_mat_id_m[i] = false;
3104 device->mul_mat_id_l[i] = false;
3105 } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
3106 device->mul_mat_id_m[i] = false;
3107 device->mul_mat_id_l[i] = false;
3108 } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
3109 device->mul_mat_id_l[i] = false;
3110 }
3111 }
3112 }
3113
3114 if (!device->pipeline_matmul_f32) {
3115 device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
3116 }
3117 if (!device->pipeline_matmul_f32_f16) {
3118 device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
3119 }
3120 if (!device->pipeline_matmul_id_f32) {
3121 device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
3122 }
3123 if (!device->pipeline_matmul_bf16) {
3124 device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
3125 }
3126 if (!device->pipeline_matmul_id_bf16) {
3127 device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
3128 }
3129
3130 std::vector<std::future<void>> compiles;
3131 auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
3132 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
3133 uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
3134
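        // If the caller did not request full subgroups or pin a subgroup size,
        // fall back to the per-architecture default for this shader.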
3135 if (!require_full_subgroups && required_subgroup_size == 0) {
3136 required_subgroup_size = get_subgroup_size(name, device->architecture);
3137 }
3138
3139 vk_pipeline *ptr = &base_pipeline;
3140
3141 int num_pipelines = 1;
3142#if defined(VK_EXT_shader_64bit_indexing)
3143 if (device->shader_64b_indexing) {
3144 num_pipelines = 2;
3145 }
3146#endif
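        // When 64-bit indexing is available, build a second variant of each pipeline with it
        // enabled; the variants are chained through pipeline->next.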
3147 for (int i = 0; i < num_pipelines; ++i, ptr = &(*ptr)->next) {
3148 vk_pipeline &pipeline = *ptr;
3149 if (!pipeline) {
3150 pipeline = std::make_shared<vk_pipeline_struct>();
3151 }
3152 if (!pipeline->initialized) {
3153 pipeline->name = name;
3154 pipeline->parameter_count = parameter_count;
3155 pipeline->push_constant_size = push_constant_size;
3156 pipeline->wg_denoms = wg_denoms;
3157 pipeline->align = align;
3158 pipeline->initialized = true;
3159#if defined(VK_EXT_shader_64bit_indexing)
3160 pipeline->is_64b_indexing = (i == 1);
3161#endif
3162 }
3163
3164 if (!pipeline->needed || pipeline->compiled) {
3165 continue;
3166 }
3167 // TODO: We're no longer benefitting from the async compiles (shaders are
3168 // compiled individually, as needed) and this complexity can be removed.
3169 {
3170 // wait until fewer than N compiles are in progress
3171 uint32_t N = std::max(1u, std::thread::hardware_concurrency());
3172 std::unique_lock<std::mutex> guard(compile_count_mutex);
3173 while (compile_count >= N) {
3174 compile_count_cond.wait(guard);
3175 }
3176 compile_count++;
3177 }
3178
3179 compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
3180 parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
3181 }
3182 };
3183
3184 auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
3185 uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
3186 uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
3187 return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
3188 parameter_count, push_constant_size, wg_denoms, specialization_constants,
3189 align, disable_robustness, require_full_subgroups, required_subgroup_size);
3190 };
3191
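    // Workgroup denominators for flash attention: the dispatch is tiled only along the first (rows) dimension.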
3192 auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
3193 return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
3194 };
3195
3196 auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
        // For a large number of rows, 128 invocations seem to work best.
        // For a small number of rows (e.g. N==1), 256 works better. But the matrix granularity for 256 is 32,
        // so we can't use 256 for D==80.
        // For the scalar path, use 128 (arbitrary).
3201 // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
3202 const uint32_t D = (hsk|hsv);
3203 auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
3204
3205 uint32_t wg_size;
3206 switch (path) {
3207 case FA_COOPMAT2:
3208 wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
3209 break;
3210 case FA_COOPMAT1:
3211 wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
3212 break;
3213 default:
3214 wg_size = scalar_flash_attention_workgroup_size;
3215 break;
3216 }
3217
3218 // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
3219 // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
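        // Isolate the lowest set bit of D.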
3220 const uint32_t D_lsb = D ^ (D & (D-1));
3221 uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
3222
        // Nvidia prefers using shared memory to load large tiles of K.
        // Switch to loading from global memory when it would use too much shared memory.
        // AMD prefers loading K directly from global memory.
3226 const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
3227
3228 return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
3229 };
3230
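    // CREATE_FA instantiates the flash attention pipelines already registered for this type on the
    // given code path, picking the aligned/unaligned and f16/f32-accumulator shader variants per entry.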
3231#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
3232 for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
3233 uint32_t HSK = fa.first.HSK; \
3234 uint32_t HSV = fa.first.HSV; \
3235 bool small_rows = fa.first.small_rows; \
3236 bool small_cache = fa.first.small_cache; \
3237 FaCodePath path = fa.first.path; \
3238 bool aligned = fa.first.aligned; \
3239 bool f32acc = fa.first.f32acc; \
3240 uint32_t flags = fa.first.flags; \
3241 if (path == FAPATH) { \
3242 if (aligned) { \
3243 if (f32acc) { \
3244 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3245 } else { \
3246 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3247 } \
3248 } else { \
3249 if (f32acc) { \
3250 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3251 } else { \
3252 ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \
3253 } \
3254 } \
3255 } \
3256 }
3257
3258 CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
3259 CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
3260 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
3261 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3262#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3263 if (device->coopmat1_fa_support) {
3264 CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
3265 CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
3266 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
3267 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
3268 }
3269#endif
3270#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3271 if (device->coopmat2) {
3272 CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
3273 CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
3274 CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
3275 CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
3276 CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
3277 CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
3278 CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
3279 CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
3280 }
3281#endif
3282#undef CREATE_FA
3283
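    // mul_mat_id pipelines bind extra buffers (notably the expert row ids), hence the larger
    // descriptor parameter count compared to the 3 used by plain matmul.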
3284 const int mul_mat_id_param_count = 5;
3285
3286#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3287 if (device->coopmat2) {
3288
3289 // Create 6 variants, {s,m,l}x{unaligned,aligned}
3290#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
3291 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \
3292 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \
3293 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \
3294 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \
3295 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \
3296 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \
3297
3298 // Create 2 variants, {f16,f32} accumulator
3299#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
3300 CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
3301 CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
3302
3303 CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
3304#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3305 if (device->coopmat_bf16_support) {
3306 CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
3307 }
3308#endif
3309 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3310 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3311 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3312 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3313 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3314 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
3315 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
3316 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
3317 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
3318 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
3319 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3320 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3321 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3322 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3323 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3324 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3325 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3326 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3327 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3328 CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3329
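    // The matmul_id shaders selected here are the subgroup variants, which rely on subgroup ballot.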
3330 GGML_ASSERT(device->subgroup_ballot);
3331
3332 CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
3333#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3334 if (device->coopmat_bf16_support) {
3335 CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
3336 }
3337#endif
3338 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3339 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3340 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3341 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3342 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3343 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3344 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3345 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3346 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3347 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3348 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3349 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3350 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3351 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3352 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3353 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3354 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3355 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3356 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3357 CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3358#undef CREATE_MM
3359#undef CREATE_MM2
3360 } else
3361#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3362#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3363 if (device->coopmat_support) {
3364 // Create 6 variants, {s,m,l}x{unaligned,aligned}
3365#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3366 if (device->mul_mat ## ID ## _l[TYPE]) \
3367 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
3368 if (device->mul_mat ## ID ## _m[TYPE]) \
3369 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
3370 if (device->mul_mat ## ID ## _s[TYPE]) \
3371 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
3372 if (device->mul_mat ## ID ## _l[TYPE]) \
3373 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
3374 if (device->mul_mat ## ID ## _m[TYPE]) \
3375 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
3376 if (device->mul_mat ## ID ## _s[TYPE]) \
3377 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
3378
3379 // Create 2 variants, {f16,f32} accumulator
3380#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3381 if (device->coopmat_acc_f16_support) { \
3382 CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3383 } \
3384 if (device->coopmat_acc_f32_support) { \
3385 CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3386 } \
3387
3388 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
3389 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
3390 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
3391 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
3392#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3393 if (device->coopmat_bf16_support) {
3394 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
3395 }
3396#endif
3397
3398 if (device->coopmat_acc_f16_support) {
3399 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3400 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3401 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3402 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3403 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3404
3405 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3406 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3407 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3408 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3409 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3410 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3411 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3412 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3413 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3414 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3415 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3416 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3417 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3418 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3419 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3420 } else {
3421 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3422 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3423 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3424 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3425 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3426
3427 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3428 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3429 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3430 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3431 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3432 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3433 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3434 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3435 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3436 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3437 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3438 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3439 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3440 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3441 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3442 }
3443
3444 GGML_ASSERT(device->subgroup_ballot);
3445
3446 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3447 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3448 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3449#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3450 if (device->coopmat_bf16_support) {
3451 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3452 }
3453#endif
3454
3455 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3456 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3457 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3458 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3459 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3460 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3461 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3462 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3463 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3464 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3465 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3466 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3467 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3468 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3469 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3470 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3471 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3472 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3473 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3474 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3475#undef CREATE_MM2
3476#undef CREATE_MM
3477 } else
3478#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3479 if (device->fp16) {
3480 // Create 6 variants, {s,m,l}x{unaligned,aligned}
3481#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3482 if (device->mul_mat ## ID ## _l[TYPE]) \
3483 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3484 if (device->mul_mat ## ID ## _m[TYPE]) \
3485 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3486 if (device->mul_mat ## ID ## _s[TYPE]) \
3487 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3488 if (device->mul_mat ## ID ## _l[TYPE]) \
3489 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3490 if (device->mul_mat ## ID ## _m[TYPE]) \
3491 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3492 if (device->mul_mat ## ID ## _s[TYPE]) \
3493 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3494
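    // q8_1 integer-dot matmul pipelines only use an f32 accumulator, so create just the unaligned {l,m,s} set.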
3495#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3496 if (device->mul_mat ## ID ## _l[TYPE]) { \
3497 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3498 } \
3499 if (device->mul_mat ## ID ## _m[TYPE]) { \
3500 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3501 } \
3502 if (device->mul_mat ## ID ## _s[TYPE]) { \
3503 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3504 } \
3505
3506 // Create 2 variants, {f16,f32} accumulator
3507#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3508 CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3509 CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3510
3511 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3512 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3513 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3514 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3515
3516 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3517
3518 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3519 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3520 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3521 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3522 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3523
3524 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3525 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3526 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3527 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3528 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3529 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3530 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3531 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3532 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3533 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3534 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3535 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3536 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3537 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3538 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3539
3540#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3541 if (device->integer_dot_product) {
3542 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3543 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3544 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3545 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3546 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3547
3548 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_MXFP4], matmul_mxfp4_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, , 0);
3549
3550 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K], matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3551 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K], matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3552 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K], matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3553 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K], matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3554 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K], matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, , 0);
3555 }
3556#endif
3557
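        // Use the subgroup-optimized mul_mat_id shaders when the device supports subgroup ballot and
        // the required subgroup size controls; otherwise fall back to the generic variants below.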
3558 if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3559 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3560 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3561 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3562 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3563
3564 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3565 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3566 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3567 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3568 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3569 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3570 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3571 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3572 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3573 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3574 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3575 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3576 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3577 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3578 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3579 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3580 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3581 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3582 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3583 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3584
3585#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3586 if (device->integer_dot_product) {
3587 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3588 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3589 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3590 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3591 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3592
3593 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3594
3595 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3596 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3597 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3598 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3599 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3600 }
3601#endif
3602 } else {
3603 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3604 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3605 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3606 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3607
3608 CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3609 CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3610 CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3611 CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3612 CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3613 CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3614 CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3615 CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3616 CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3617 CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3618 CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3619 CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3620 CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3621 CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3622 CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3623 CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3624 CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3625 CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3626 CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3627 CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3628
3629#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3630 if (device->integer_dot_product) {
3631 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_0], matmul_id_q4_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3632 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_1], matmul_id_q4_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3633 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_0], matmul_id_q5_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3634 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_1], matmul_id_q5_1_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3635 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q8_0], matmul_id_q8_0_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3636
3637 CREATE_MMQ(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_MXFP4], matmul_id_mxfp4_q8_1, mmq_wg_denoms, warptile_mmqid_int, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3638
3639 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q2_K], matmul_id_q2_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3640 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q3_K], matmul_id_q3_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3641 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q4_K], matmul_id_q4_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3642 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q5_K], matmul_id_q5_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3643 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id_q8_1[GGML_TYPE_Q6_K], matmul_id_q6_k_q8_1, mmq_wg_denoms, warptile_mmqid_int_k, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3644 }
3645#endif
3646 }
3647#undef CREATE_MM2
3648#undef CREATE_MMQ
3649#undef CREATE_MM
3650 } else {
3651 // Create 6 variants, {s,m,l}x{unaligned,aligned}
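    // This is the FP32 shader path: every pipeline below is built from the *_fp32
    // SPIR-V variants of the matmul shaders.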
3652#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3653 if (device->mul_mat ## ID ## _l[TYPE]) \
3654 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3655 if (device->mul_mat ## ID ## _m[TYPE]) \
3656 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3657 if (device->mul_mat ## ID ## _s[TYPE]) \
3658 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3659 if (device->mul_mat ## ID ## _l[TYPE]) \
3660 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3661 if (device->mul_mat ## ID ## _m[TYPE]) \
3662 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3663 if (device->mul_mat ## ID ## _s[TYPE]) \
3664 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3665
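// Unlike CREATE_MM above, CREATE_MMQ only creates the three unaligned {l,m,s}
// variants for the q8_1 integer-dot pipelines.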
3666#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3667 if (device->mul_mat ## ID ## _l[TYPE]) \
3668 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
3669 if (device->mul_mat ## ID ## _m[TYPE]) \
3670 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
3671 if (device->mul_mat ## ID ## _s[TYPE]) \
3672 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
3673
3674 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3675 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3676 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3677 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3678
3679 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3680
3681 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3682 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3683 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3684 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3685 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3686
3687 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3688 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3689 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3690 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3691 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3692 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3693 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3694 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3695 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3696 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3697 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3698 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3699 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3700 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3701 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3702
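    // q8_1 integer-dot MMQ pipelines: only compiled in when glslc support for the
    // integer dot product feature is available, and only created when the device
    // reports it. K-quants use the dedicated warptile_mmq_int_k tile configuration.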
3703#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3704 if (device->integer_dot_product) {
3705 CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3706 CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3707 CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3708 CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3709 CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
3710
3711 CREATE_MMQ(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3712 CREATE_MMQ(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3713 CREATE_MMQ(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3714 CREATE_MMQ(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3715 CREATE_MMQ(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_q8_1, mmq_wg_denoms, warptile_mmq_int_k, vk_mat_mat_push_constants, 3, );
3716 }
3717#endif
3718
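    // mul_mat_id (indirect matmul) pipelines: prefer the subgroup-optimized shaders
    // when subgroup ballot, full subgroup support and a subgroup size of at least 16
    // are all available; otherwise fall back to the plain variants below.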
3719 if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) {
3720 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3721 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3722 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3723 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3724
3725 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3726 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3727 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3728 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3729 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3730 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3731 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3732 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3733 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3734 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3735 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3736 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3737 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3738 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3739 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3740 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3741 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3742 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3743 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3744 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3745 } else {
3746 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3747 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3748 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3749 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3750
3751 CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3752 CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3753 CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3754 CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3755 CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3756 CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3757 CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3758 CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3759 CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3760 CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3761 CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3762 CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3763 CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3764 CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3765 CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3766 CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3767 CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3768 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3769 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3770 CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3771 }
3772 }
3773    // Re-use the scalar CREATE_MM from the fp32 path: coopmat-capable devices that
    // cannot run bf16 through cooperative matrix get scalar bf16 matmul pipelines.
3774 if ((device->coopmat2 || device->coopmat_support)
3775#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3776 && !device->coopmat_bf16_support
3777#endif
3778 ) {
3779 // use scalar tile sizes
3780 l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
3781 m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
3782 s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
3783
3784 l_wg_denoms = {128, 128, 1 };
3785 m_wg_denoms = { 64, 64, 1 };
3786 s_wg_denoms = { 32, 32, 1 };
3787
3788 if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
3789 // Xe2/Xe3 - bf16 warptile performance tuning
3790 l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
3791 }
3792
3793 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3794 CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3795 }
3796#undef CREATE_MM
3797
3798 // mul mat vec
3799
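    // Each dequant mul_mat_vec shader is instantiated for every workgroup-size
    // variant (DMMV_WG_SIZE_*), every supported destination column count, and a
    // reduction mode (subgroup, hybrid, or shared memory) chosen from the device's
    // subgroup capabilities.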
3800    // the number of rows computed per shader workgroup depends on the GPU model and the quantization type
3801 uint32_t rm_stdq = 1;
3802 uint32_t rm_kq = 2;
3803 uint32_t rm_stdq_int = 1;
3804 uint32_t rm_kq_int = 1;
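    // Row multiplier for the i-quant integer-dot shaders: 8 rows when there is a
    // single destination column (i == 0), 4 rows otherwise.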
3805 auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
3806 if (device->vendor_id == VK_VENDOR_ID_AMD) {
3807 if (device->architecture == AMD_GCN) {
3808 rm_stdq = 2;
3809 rm_kq = 4;
3810 rm_stdq_int = 4;
3811 }
3812 } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
3813 rm_stdq = 2;
3814 rm_stdq_int = 2;
3815 }
3816 uint32_t rm_iq = 2 * rm_kq;
3817
3818 const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
3819 // Ensure a subgroup size >= 16 is available
3820 const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
3821
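    // On Intel, prefer a 16-wide subgroup via subgroup size control when the
    // supported min/max range allows it; otherwise use the default subgroup size.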
3822 const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
3823 const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
3824
3825 const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
3826 const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
3827 static constexpr uint32_t mul_mat_vec_num_bindings = 5;
3828 static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
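    // The _id variants use one extra descriptor binding, presumably for the ids
    // tensor that selects which matrix each row is multiplied against.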
3829
3830 for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
3831 const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
3832 const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
3833
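        // Pick the reduction strategy: pure subgroup ops when the workgroup is a
        // single subgroup, SHADER_REDUCTION_MODE_HYBRID for the larger (4x subgroup)
        // workgroup size, and plain shared memory when subgroup ops are unavailable.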
3834 const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
3835 (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
3836 SHADER_REDUCTION_MODE_SHMEM;
3837
3838 const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
3839 (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
3840 SHADER_REDUCTION_MODE_SHMEM;
3841
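        // One pipeline per destination column count; the column count (i+1) is baked
        // in as a specialization constant alongside the workgroup size and row count.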
3842 for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
3843 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
3844 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3845 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3846 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3847 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3848 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3849 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3850 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3851 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3852 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3853 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3854 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3855 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3856 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3857 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3858 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3859 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3860 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3861 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3862 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3863 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3864 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3865 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3866
3867 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
3868 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3869 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
3870 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3871 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3872 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3873 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3874 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
3875 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3876 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3877 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3878 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3879 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3880 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3881 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3882 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3883 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3884 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3885 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3886 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3887 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3888 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3889 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
3890
3891#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3892 if (device->integer_dot_product) {
3893 const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
3894 const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
3895
3896 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3897 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3898 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3899 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3900 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3901
3902 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3903
3904 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3905 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3906 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3907 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3908 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
3909
3910 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3911 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
3912
3913 }
3914#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3915 }
3916
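        // mul_mat_vec_id (indirect matrix-vector) pipelines: unlike the pipelines
        // above, these are not specialized per destination column count.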
3917 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
3918 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
3919 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
3920 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
3921 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
3922 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
3923 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
3924 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
3925 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
3926 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
3927 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
3928 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
3929 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
3930 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3931 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3932 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3933 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3934 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3935 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3936 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3937 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3938 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3939 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
3940
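    // Integer dot product (q8_1) variants of the indirect mul_mat_vec shaders.
    // Intel uses its minimum subgroup size here when subgroup size control is available.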
3941#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3942 if (device->integer_dot_product) {
3943 const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
3944 const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
3945
3946 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3947 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3948 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3949 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3950 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3951
3952 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
3953
3954 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3955 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3956 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3957 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3958 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
3959
3960 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
3961 ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
3962 }
3963#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
3964 }
3965
3966#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3967 GGML_UNUSED(rm_stdq_int);
3968 GGML_UNUSED(rm_kq_int);
3969 GGML_UNUSED(rm_iq_int);
3970#endif
3971
3972 // dequant shaders
3973 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3974 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3975 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3976 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3977 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3978 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3979 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3980 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3981 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3982 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3983 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1);
3984 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3985 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3986 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3987 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3988 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3989 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3990 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3991 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
3992 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3993 ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
3994
3995 // get_rows
3996 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3997 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3998 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
3999 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4000 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4001 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4002 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4003 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4004 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4005 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4006 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4007 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4008 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4009 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4010 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4011 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4012 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4013 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4014 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4015 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4016 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4017 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4018 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4019 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4020
4021 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4022 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4023 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4024 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4025 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4026 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4027 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4028 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4029 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4030 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4031 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4032 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4033 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4034 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4035 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4036 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4037 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4038 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4039 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4040 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4041 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4042 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4043 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4044
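    // split-k reduction passes for matmul and flash attention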
4045 ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
4046 ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
4047
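    // one flash-attention mask-optimization pipeline per (Br, Bc) tile size requested in pipeline_fa_mask_opt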
4048 for (auto &it : device->pipeline_fa_mask_opt) {
4049 auto BrBc = it.first;
4050 ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
4051 }
4052
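    // q8_1 (x4) quantization pipeline; the subgroup-clustered shader is only used when full subgroups are guaranteed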
4053 if (device->subgroup_clustered && device->subgroup_require_full_support) {
4054 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
4055 } else {
4056 ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
4057 }
4058
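    // mul_mat_vec for the permuted f16 (p021) layout, one pipeline per GQA ratio;
    // a subgroup-add shader is used when full subgroup support is guaranteed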
4059 for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
4060 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
4061 ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
4062 } else {
4063 ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_p021_push_constants), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
4064 }
4065 }
4066 ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1);
4067
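    // normalization pipelines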
4068 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4069 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4070
4071 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
4072 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4073 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
4074 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4075
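    // fused RMS-norm + mul + RoPE, only when fp16 round-to-nearest-even is supported and the push constants fit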
4076 if (device->float_controls_rte_fp16 &&
4077 sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
4078 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4079 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4080 }
4081
4082 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4083 ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4084
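    // strided copy / type-conversion pipelines, with contiguous variants below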
4085 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4086 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4087 ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4088 ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4089    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16, "cpy_f32_bf16", cpy_f32_bf16_len, cpy_f32_bf16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4090 ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4091 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4092
4093 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4094 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4095 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4096 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4097    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16, "contig_cpy_f32_bf16", contig_cpy_f32_bf16_len, contig_cpy_f32_bf16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4098 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4099 ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4100
4101 ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
4102 ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
4103
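    // f32 -> quantized copies; RTE shader variants are used when fp16 round-to-nearest-even float controls are available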
4104 if (device->float_controls_rte_fp16) {
4105 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4106 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4107 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4108 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4109 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4110 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4111 } else {
4112 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4113 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4114 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4115 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4116 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4117 ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4118 }
4119
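    // set_rows pipelines for every destination type, specialized below for i32 and i64 row indices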
4120#define SET_ROWS(itype, rte) \
4121 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4122 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4123 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4124 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4125 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4126 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4127 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4128 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4129 ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
4130
4131 if (device->float_controls_rte_fp16) {
4132 SET_ROWS(_i32, _rte)
4133 SET_ROWS(_i64, _rte)
4134 } else {
4135 SET_ROWS(_i32, )
4136 SET_ROWS(_i64, )
4137 }
4138#undef SET_ROWS
4139
4140
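    // quantized -> f32 copies; the workgroup size matches the block size of each quantized type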
4141 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
4142 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
4143 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
4144 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
4145 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
4146 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
4147
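    // element-wise binary ops for every f16/f32 combination of src0, src1 and dst,
    // with separate _norepeat variants selected via a spec constant (see CREATE_BINARY below)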
4148 auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
4149 std::string s;
4150 s += std::string(src0_f16 ? "_f16" : "_f32");
4151 s += std::string(src1_f16 ? "_f16" : "_f32");
4152 s += std::string(dst_f16 ? "_f16" : "_f32");
4153 return s;
4154 };
4155
4156 bool rte = device->float_controls_rte_fp16;
4157#define CREATE_BINARY(name, namemod, spec, bindings) \
4158 for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
4159 ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
4160 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
4161 "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
4162
4163 CREATE_BINARY(add, , {0}, 4)
4164 CREATE_BINARY(add, _norepeat, {1}, 4)
4165 CREATE_BINARY(sub, , {0}, 3)
4166 CREATE_BINARY(sub, _norepeat, {1}, 3)
4167 CREATE_BINARY(mul, , {0}, 3)
4168 CREATE_BINARY(mul, _norepeat, {1}, 3)
4169 CREATE_BINARY(div, , {0}, 3)
4170 CREATE_BINARY(div, _norepeat, {1}, 3)
4171 CREATE_BINARY(add_rms, , {0}, 4)
4172 CREATE_BINARY(add_rms, _norepeat, {1}, 4)
4173#undef CREATE_BINARY
4174
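    // fused add chains; the spec constant carries the operand count for each variant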
4175 if (device->multi_add) {
4176 for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) {
4177 ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
4178 ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1);
4179 }
4180 }
4181
4182 ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
4183
4184 ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4185
4186 ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4187 ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4188 ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4189
4190 ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1);
4191 ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1);
4192 ggml_vk_create_pipeline(device, device->pipeline_upscale_bicubic_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BICUBIC}, 1);
4193 ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_antialias_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS}, 1);
4194
4195 ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4196
4197 ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4198 ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4199 ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4200 ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4201
4202 if (device->float_controls_rte_fp16) {
4203 ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4204 ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4205 } else {
4206 ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4207 ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4208 }
4209
4210 ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4211 ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4212
4213 ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4214 ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4215
4216 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4217
4218 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
4219
4220 ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4221
4222 ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4223 ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4224
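    // element-wise unary ops, f32 and f16 variants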
4225#define CREATE_UNARY(name) \
4226 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4227 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4228
4229 CREATE_UNARY(gelu)
4230 CREATE_UNARY(gelu_erf)
4231 CREATE_UNARY(gelu_quick)
4232 CREATE_UNARY(silu)
4233 CREATE_UNARY(relu)
4234 CREATE_UNARY(xielu)
4235 CREATE_UNARY(neg)
4236 CREATE_UNARY(tanh)
4237 CREATE_UNARY(sigmoid)
4238 CREATE_UNARY(hardsigmoid)
4239 CREATE_UNARY(hardswish)
4240 CREATE_UNARY(abs)
4241 CREATE_UNARY(softplus)
4242 CREATE_UNARY(step)
4243 CREATE_UNARY(round)
4244 CREATE_UNARY(ceil)
4245 CREATE_UNARY(floor)
4246 CREATE_UNARY(trunc)
4247#undef CREATE_UNARY
4248
4249#define CREATE_UNARY_RTE(name) \
4250 if (device->float_controls_rte_fp16) { \
4251 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4252 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4253 } else { \
4254 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4255 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4256 }
4257 CREATE_UNARY_RTE(exp)
4258#undef CREATE_UNARY_RTE
4259
4260 ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4261 ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4262 ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4263
4264 ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4265
4266 ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4267
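    // gated linear unit (GLU) variants, using RTE shaders when fp16 round-to-nearest-even is supported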
4268#define CREATE_GLU(name) \
4269 if (device->float_controls_rte_fp16) { \
4270 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4271 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4272 } else { \
4273 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4274 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4275 }
4276
4277 CREATE_GLU(geglu)
4278 CREATE_GLU(reglu)
4279 CREATE_GLU(swiglu)
4280 CREATE_GLU(swiglu_oai)
4281 CREATE_GLU(geglu_erf)
4282 CREATE_GLU(geglu_quick)
4283#undef CREATE_GLU
4284
4285 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4286 ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4287
4288 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
4289
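    // softmax: subgroup-sized and 512-wide workgroup variants, plus multi-pass "large" pipelines for wide rows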
4290 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4291 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
4292 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4293 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
4294 ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
4295
4296 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4297 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4298 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4299 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4300 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4301 ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4302
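    // RoPE pipelines (norm / neox / multi / vision); f16 outputs use RTE shaders when supported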
4303 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4304 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4305 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4306 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4307
4308 if (device->float_controls_rte_fp16) {
4309 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4310 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4311 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4312 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4313
4314 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4315 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4316 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4317 } else {
4318 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4319 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4320 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4321 ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4322
4323 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4324 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4325 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4326 }
4327
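    // argsort: a shared-memory pipeline when the padded column count fits, and a "large" fallback split across unrolled workgroups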
4328 for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
4329 uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
4330 if (i <= device->max_workgroup_size_log2 &&
4331 2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
4332 const uint32_t NCOLS_PADDED_LOG2 = i;
4333 ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
4334 }
4335 const uint32_t WG_UNROLL_FACTOR = BLOCK_SIZE > 1 ? 2 : 1;
4336 BLOCK_SIZE /= WG_UNROLL_FACTOR;
4337 ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
4338 }
4339
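    // top-k: an n-ary-search pipeline when the required subgroup features and shared memory are available, otherwise an argsort-based fallback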
4340 for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
4341 const uint32_t BLOCK_SIZE = 1u << i;
4342 const uint32_t NCOLS_PADDED_LOG2 = i;
4343 if (i <= device->max_workgroup_size_log2) {
4344 uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
4345 sizeof(int) * device->subgroup_size +
4346 2 * sizeof(int) +
4347 2 * (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
4348 if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
4349 nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
4350 ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
4351 } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
4352 ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
4353 }
4354 }
4355 }
4356
4357 ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4358
4359 ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4360
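    // Cumulative sum: a single-pass shader with elements-per-thread tuned per vendor, a smaller
    // single-pass variant, and a two-pass (multipass1/multipass2) path.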
4361 const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
4362 ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
4363    ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_small_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
4364 ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
4365 ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
4366
4367 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
4368
4369 ggml_vk_create_pipeline(device, device->pipeline_count_experts, "count_experts", count_experts_len, count_experts_data, "main", 2, sizeof(vk_op_count_experts_push_constants), {1, 1, 1}, {}, 1, true);
4370
4371 for (auto &s : device->pipeline_solve_tri_f32) {
4372 const vk_solve_tri_pipeline_state &state = s.first;
4373
4374 // Max number of rows to load at a time, limited by shared memory
4375 const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((state.N + state.K) * sizeof(float));
4376 // Need at least K invocations, and prefer a minimum of 128 to spread out loading shared memory
4377 const uint32_t block_size = std::max(128u, 1u << (uint32_t)ceilf(log2f(float(state.K))));
4378
4379 ggml_vk_create_pipeline(
4380 device, s.second, "solve_tri_f32",
4381 solve_tri_f32_len, solve_tri_f32_data, "main", 3,
4382 sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K, batch_N, block_size }, 1, true);
4383 }
4384
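// im2col pipelines come in two flavors via this macro: with buffer device addresses (_bda) when int64
// and BDA support are available, and without otherwise. RTE shader variants are used for f16 output
// when the device supports round-to-nearest-even float controls.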
4385#define IM2COL(bda) \
4386 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4387 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4388 if (device->float_controls_rte_fp16) { \
4389 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4390 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4391 } else { \
4392 ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4393 ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4394 }
4395 if (device->shader_int64 && device->buffer_device_address) {
4396 IM2COL(_bda)
4397 } else {
4398 IM2COL()
4399 }
4400
4401 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
4402
4403 ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
4404
4405 ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
4406
4407 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
4408
4409 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
4410
4411 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
4412 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
4413 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
4414 } else {
4415 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
4416 ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
4417 }
4418
4419 ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
4420
4421 ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4422
4423 ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4424
4425 // conv2d, conv_transpose_2d
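    // One pipeline set per block-size shape; the specialization constants carry the workgroup size,
    // block sizes, tile size, collectives flag, shared-memory padding and the per-variant conv parameters.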
4426 for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
4427 uint32_t conv2d_WG_SIZE = 256;
4428        uint32_t use_collectives = 0; // Enable subgroup ops to avoid recalculating indices.
4429 uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8;
4430 uint32_t conv2d_SHMEM_PAD = 4;
4431 vk_conv_block_size conv2d_BS = vk_conv_block_sizes[s];
4432 bool conv2d_UNROLL = true;
4433
4434#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4435 if (device->coopmat2) {
4436 conv2d_SHMEM_PAD = 8; // 8 float16_t
4437 }
4438#endif
4439
4440 if (device->vendor_id == VK_VENDOR_ID_INTEL) {
4441 conv2d_SHMEM_PAD = 0;
4442 conv2d_UNROLL = false;
4443 } else if (device->vendor_id == VK_VENDOR_ID_AMD) {
4444 conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4;
4445 if (s == CONV_SHAPE_128x128 && device->architecture != vk_device_architecture::AMD_GCN) {
4446 conv2d_UNROLL = false;
4447 }
4448 }
4449
4450 // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math.
4451 bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA ||
4452 device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
4453 bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD ||
4454 device->architecture == vk_device_architecture::AMD_GCN;
4455
4456 if (device->subgroup_shuffle &&
4457 device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316.
4458 allow_collectives_nv &&
4459 allow_collectives_amd) {
4460 use_collectives = 1;
4461 conv2d_BS.CRS = std::min(
4462 device->subgroup_size,
4463 conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
4464 }
4465
4466 uint32_t conv2d_shmem_req =
4467 (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
4468 if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
4469 conv2d_BS.CRS = 8;
4470 if (use_collectives) {
4471 conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
4472 }
4473 }
4474
4475 std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
4476 std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
4477
4478#define CREATE_CONV(name, type_suffix, spv_suffix) \
4479 for (auto &c : device->pipeline_##name##type_suffix[s]) { \
4480 const vk_conv2d_pipeline_state &state = c.first; \
4481 std::vector<uint32_t> spec_constants_cpy = spec_constants; \
4482 spec_constants_cpy.push_back(state.s0); \
4483 spec_constants_cpy.push_back(state.s1); \
4484 spec_constants_cpy.push_back(state.p0); \
4485 spec_constants_cpy.push_back(state.p1); \
4486 spec_constants_cpy.push_back(state.d0); \
4487 spec_constants_cpy.push_back(state.d1); \
4488 spec_constants_cpy.push_back(state.KW); \
4489 spec_constants_cpy.push_back(state.KH); \
4490 ggml_vk_create_pipeline( \
4491 device, c.second, #name #type_suffix, \
4492 name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
4493 sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
4494 }
4495#define CREATE_CONVS(spv_suffix) \
4496 CREATE_CONV(conv2d, _f32, spv_suffix) \
4497 CREATE_CONV(conv2d, _f16_f32, spv_suffix) \
4498 CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \
4499 CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix)
4500#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4501 if (device->coopmat2) {
4502 CREATE_CONVS(_cm2)
4503 } else
4504#endif
4505 if (conv2d_UNROLL) {
4506 CREATE_CONVS(_unroll)
4507 } else {
4508 CREATE_CONVS( )
4509 }
4510#undef CREATE_CONV
4511#undef CREATE_CONVS
4512 }
4513
4514 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
4515 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
4516 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
4517 ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
4518
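    // topk_moe: one pipeline per supported power-of-two expert count, in variants selected by the
    // use_push specialization constant.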
4519 for (uint32_t use_push = 0; use_push < 2; ++use_push) {
4520 for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4521 ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
4522 }
4523 }
4524
4525 for (auto &c : compiles) {
4526 c.wait();
4527 }
4528}
4529
4530static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
4531
4532static vk_device ggml_vk_get_device(size_t idx) {
4533 VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
4534
4535 if (vk_instance.devices[idx] == nullptr) {
4536 VK_LOG_DEBUG("Initializing new vk_device");
4537 vk_device device = std::make_shared<vk_device_struct>();
4538 vk_instance.devices[idx] = device;
4539
4540 device->memory_logger = std::unique_ptr<vk_memory_logger>(new vk_memory_logger());
4541
4542 size_t dev_num = vk_instance.device_indices[idx];
4543
4544 std::vector<vk::PhysicalDevice> physical_devices = vk_instance.instance.enumeratePhysicalDevices();
4545
4546 if (dev_num >= physical_devices.size()) {
4547 std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
4548 throw std::runtime_error("Device not found");
4549 }
4550
4551 device->physical_device = physical_devices[dev_num];
4552 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
4553
4554 device->architecture = get_device_architecture(device->physical_device);
4555
4556 const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
4557 device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
4558
4559 const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM");
4560 device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr;
4561
4562 const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
4563 device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
4564
4565 const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE");
4566 device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr;
4567
4568 bool fp16_storage = false;
4569 bool fp16_compute = false;
4570 bool maintenance4_support = false;
4571 bool sm_builtins = false;
4572 bool amd_shader_core_properties2 = false;
4573 bool pipeline_robustness = false;
4574 bool coopmat2_support = false;
4575 bool pipeline_executable_properties_support = false;
4576 device->coopmat_support = false;
4577 device->integer_dot_product = false;
4578 device->shader_64b_indexing = false;
4579 bool bfloat16_support = false;
4580
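        // Scan the device extensions and record which optional features are available
        // (several can be disabled through GGML_VK_* environment variables).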
4581 for (const auto& properties : ext_props) {
4582 if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
4583 maintenance4_support = true;
4584 } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
4585 fp16_storage = true;
4586 } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
4587 fp16_compute = true;
4588 } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
4589 sm_builtins = true;
4590 } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
4591 amd_shader_core_properties2 = true;
4592 } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
4593 pipeline_robustness = true;
4594 } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
4595 device->subgroup_size_control = true;
4596#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
4597 } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
4598 !getenv("GGML_VK_DISABLE_COOPMAT")) {
4599 device->coopmat_support = true;
4600 device->coopmat_m = 0;
4601 device->coopmat_n = 0;
4602 device->coopmat_k = 0;
4603#endif
4604#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4605 } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
4606 !getenv("GGML_VK_DISABLE_COOPMAT2")) {
4607 coopmat2_support = true;
4608#endif
4609#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
4610 } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
4611 !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
4612 device->integer_dot_product = true;
4613#endif
4614#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
4615 } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
4616 !getenv("GGML_VK_DISABLE_BFLOAT16")) {
4617 bfloat16_support = true;
4618#endif
4619 } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
4620 pipeline_executable_properties_support = true;
4621 } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
4622 getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
4623 device->memory_priority = true;
4624 } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
4625 device->external_memory_host = true;
4626#if defined(VK_EXT_shader_64bit_indexing)
4627 } else if (strcmp("VK_EXT_shader_64bit_indexing", properties.extensionName) == 0) {
4628 device->shader_64b_indexing = true;
4629#endif
4630 }
4631 }
4632
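        // Chain the property structs through pNext so a single getProperties2() call fills them all.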
4633 vk::PhysicalDeviceProperties2 props2;
4634 vk::PhysicalDeviceMaintenance3Properties props3;
4635 vk::PhysicalDeviceMaintenance4Properties props4;
4636 vk::PhysicalDeviceSubgroupProperties subgroup_props;
4637 vk::PhysicalDeviceDriverProperties driver_props;
4638 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
4639 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
4640 vk::PhysicalDeviceVulkan11Properties vk11_props;
4641 vk::PhysicalDeviceVulkan12Properties vk12_props;
4642 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
4643 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
4644 vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
4645
4646 props2.pNext = &props3;
4647 props3.pNext = &subgroup_props;
4648 subgroup_props.pNext = &driver_props;
4649 driver_props.pNext = &vk11_props;
4650 vk11_props.pNext = &vk12_props;
4651
4652 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
4653
4654 if (maintenance4_support) {
4655 last_struct->pNext = (VkBaseOutStructure *)&props4;
4656 last_struct = (VkBaseOutStructure *)&props4;
4657 }
4658 if (sm_builtins) {
4659 last_struct->pNext = (VkBaseOutStructure *)&sm_props;
4660 last_struct = (VkBaseOutStructure *)&sm_props;
4661 }
4662 if (amd_shader_core_properties2) {
4663 last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
4664 last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
4665 }
4666 if (device->subgroup_size_control) {
4667 last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
4668 last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
4669 }
4670
4671#if defined(VK_NV_cooperative_matrix2)
4672 vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
4673 if (coopmat2_support) {
4674 last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
4675 last_struct = (VkBaseOutStructure *)&coopmat2_props;
4676 }
4677#endif
4678
4679 if (device->integer_dot_product) {
4680 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4681 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
4682 }
4683
4684 if (device->external_memory_host) {
4685 last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
4686 last_struct = (VkBaseOutStructure *)&external_memory_host_props;
4687 }
4688
4689 device->physical_device.getProperties2(&props2);
4690 device->properties = props2.properties;
4691 device->vendor_id = device->properties.vendorID;
4692 device->driver_id = driver_props.driverID;
4693
4694 if (device->driver_id == vk::DriverId::eMoltenvk) {
4695 // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
4696 // is available in the Vulkan SDK.
4697 device->external_memory_host = false;
4698 }
4699
4700 // Implementing the async backend interfaces seems broken on older Intel HW,
4701 // see https://github.com/ggml-org/llama.cpp/issues/17302.
4702 device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
4703 std::string(device->properties.deviceName.data()).find("(DG1)") == std::string::npos) &&
4704 getenv("GGML_VK_DISABLE_ASYNC") == nullptr;
4705
4706 if (!device->support_async) {
4707            GGML_LOG_DEBUG("ggml_vulkan: WARNING: Async execution disabled (unsupported device or GGML_VK_DISABLE_ASYNC set).\n");
4708 }
4709
4710 const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
4711
4712 if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
4713 device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
4714 } else if (maintenance4_support) {
4715 device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
4716 } else {
4717 device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
4718 }
4719
4720 const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE");
4721
4722 if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) {
4723 device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE);
4724 } else if (maintenance4_support) {
4725 device->max_buffer_size = props4.maxBufferSize;
4726 } else {
4727 device->max_buffer_size = device->max_memory_allocation_size;
4728 }
4729
4730 const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
4731
4732 if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
4733 device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
4734 } else {
4735 // Limit batching of allocations to 1GB by default to avoid fragmentation issues
4736 device->suballocation_block_size = 1024*1024*1024;
4737 }
4738 device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
4739
4740 device->subgroup_size = subgroup_props.subgroupSize;
4741 device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
4742 device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
4743 if (sm_builtins) {
4744 device->shader_core_count = sm_props.shaderSMCount;
4745 } else if (amd_shader_core_properties2) {
4746 device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
4747 } else {
4748 device->shader_core_count = 0;
4749 }
4750 device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
4751
4752 device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4753 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic);
4754 device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4755 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
4756#ifdef __APPLE__
4757 // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)
4758 if (device->vendor_id == VK_VENDOR_ID_AMD) {
4759 device->subgroup_arithmetic = false;
4760 }
4761#endif
4762 device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4763 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
4764 device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4765 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
4766
4767 device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4768 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
4769
4770 device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4771 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
4772
4773 const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
4774
4775 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
4776
4777 if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
4778 device->coopmat_support = false;
4779 }
4780
4781 device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
4782
4783 device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
4784
4785 device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
4786
4787 std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
4788
4789 // Try to find a non-graphics compute queue and transfer-focused queues
4790 const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
4791 const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
4792
4793 const float priorities[] = { 1.0f, 1.0f };
4794 device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
4795
4796 std::vector<vk::DeviceQueueCreateInfo> device_queue_create_infos;
4797 if (compute_queue_family_index != transfer_queue_family_index) {
4798 device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
4799 device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1});
4800        } else if (!device->single_queue) {
4801 device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities});
4802 } else {
4803 device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
4804 }
4805 vk::DeviceCreateInfo device_create_info;
4806 std::vector<const char *> device_extensions;
4807 vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
4808
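        // Build the feature query chain; the same chain is later attached to the device create info
        // so the enabled features match what was queried.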
4809 VkPhysicalDeviceFeatures2 device_features2;
4810 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
4811 device_features2.pNext = nullptr;
4812 device_features2.features = (VkPhysicalDeviceFeatures)device_features;
4813
4814 VkPhysicalDeviceVulkan11Features vk11_features;
4815 vk11_features.pNext = nullptr;
4816 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
4817 device_features2.pNext = &vk11_features;
4818
4819 VkPhysicalDeviceVulkan12Features vk12_features;
4820 vk12_features.pNext = nullptr;
4821 vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
4822 vk11_features.pNext = &vk12_features;
4823
4824 last_struct = (VkBaseOutStructure *)&vk12_features;
4825
4826 VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
4827 pl_robustness_features.pNext = nullptr;
4828 pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
4829 pl_robustness_features.pipelineRobustness = VK_FALSE;
4830
4831 if (pipeline_robustness) {
4832 last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
4833 last_struct = (VkBaseOutStructure *)&pl_robustness_features;
4834 device_extensions.push_back("VK_EXT_pipeline_robustness");
4835 }
4836
4837 VkPhysicalDeviceMemoryPriorityFeaturesEXT memory_priority_features;
4838 memory_priority_features.pNext = nullptr;
4839 memory_priority_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MEMORY_PRIORITY_FEATURES_EXT;
4840 memory_priority_features.memoryPriority = VK_FALSE;
4841 if (device->memory_priority) {
4842 last_struct->pNext = (VkBaseOutStructure *)&memory_priority_features;
4843 last_struct = (VkBaseOutStructure *)&memory_priority_features;
4844 device_extensions.push_back("VK_EXT_memory_priority");
4845 }
4846
4847 VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
4848 subgroup_size_control_features.pNext = nullptr;
4849 subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
4850 subgroup_size_control_features.computeFullSubgroups = false;
4851 subgroup_size_control_features.subgroupSizeControl = false;
4852
4853 if (device->subgroup_size_control) {
4854 last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
4855 last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
4856 }
4857
4858#if defined(VK_KHR_cooperative_matrix)
4859 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
4860 coopmat_features.pNext = nullptr;
4861 coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
4862 coopmat_features.cooperativeMatrix = VK_FALSE;
4863
4864 if (device->coopmat_support) {
4865 last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
4866 last_struct = (VkBaseOutStructure *)&coopmat_features;
4867 }
4868#endif
4869
4870#if defined(VK_NV_cooperative_matrix2)
4871 VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
4872 coopmat2_features.pNext = nullptr;
4873 coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
4874 if (coopmat2_support) {
4875 last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
4876 last_struct = (VkBaseOutStructure *)&coopmat2_features;
4877 device_extensions.push_back("VK_NV_cooperative_matrix2");
4878 }
4879#endif
4880
4881#if defined(VK_KHR_shader_bfloat16)
4882 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
4883 bfloat16_features.pNext = nullptr;
4884 bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
4885 if (bfloat16_support) {
4886 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
4887 last_struct = (VkBaseOutStructure *)&bfloat16_features;
4888 device_extensions.push_back("VK_KHR_shader_bfloat16");
4889 }
4890#endif
4891
4892 VkPhysicalDeviceMaintenance4Features maint4_features {};
4893 maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
4894 if (maintenance4_support) {
4895 last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
4896 last_struct = (VkBaseOutStructure *)&maint4_features;
4897 device_extensions.push_back("VK_KHR_maintenance4");
4898 }
4899
4900 VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
4901 shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
4902 if (device->integer_dot_product) {
4903 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
4904 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
4905 device_extensions.push_back("VK_KHR_shader_integer_dot_product");
4906 }
4907
4908 VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
4909 pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
4910 if (pipeline_executable_properties_support) {
4911 last_struct->pNext = (VkBaseOutStructure *)&pep_features;
4912 last_struct = (VkBaseOutStructure *)&pep_features;
4913 device_extensions.push_back("VK_KHR_pipeline_executable_properties");
4914 }
4915
4916 if (device->external_memory_host) {
4917 device_extensions.push_back("VK_EXT_external_memory_host");
4918 }
4919
4920#if defined(VK_EXT_shader_64bit_indexing)
4921 VkPhysicalDeviceShader64BitIndexingFeaturesEXT shader_64bit_indexing_features {};
4922 shader_64bit_indexing_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_64_BIT_INDEXING_FEATURES_EXT;
4923 if (device->shader_64b_indexing) {
4924 last_struct->pNext = (VkBaseOutStructure *)&shader_64bit_indexing_features;
4925 last_struct = (VkBaseOutStructure *)&shader_64bit_indexing_features;
4926 device_extensions.push_back("VK_EXT_shader_64bit_indexing");
4927 }
4928#endif
4929
4930 vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
4931
4932 device->pipeline_executable_properties_support = pipeline_executable_properties_support;
4933
4934 device->fp16 = device->fp16 && vk12_features.shaderFloat16;
4935
4936#if defined(VK_KHR_shader_bfloat16)
4937 device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
4938#else
4939 device->bf16 = false;
4940#endif
4941
4942 device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
4943
4944 device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
4945 device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
4946 getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
4947
4948 device->shader_int64 = device_features2.features.shaderInt64;
4949 device->buffer_device_address = vk12_features.bufferDeviceAddress;
4950 device->vulkan_memory_model = vk12_features.vulkanMemoryModel;
4951
4952 if (device->subgroup_size_control) {
4953 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
4954 device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
4955 device_extensions.push_back("VK_EXT_subgroup_size_control");
4956 }
4957
4958 device->subgroup_size_control = device->subgroup_size_control &&
4959 (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
4960 subgroup_size_control_features.subgroupSizeControl;
4961
4962 device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
4963
4964#if defined(VK_KHR_cooperative_matrix)
4965 device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
4966
4967 // coopmat1 fa shader currently assumes 32 invocations per subgroup
4968 device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
4969 device->subgroup_size_control && device->subgroup_min_size <= 32 &&
4970 device->subgroup_max_size >= 32;
4971#endif
4972
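        // NV_cooperative_matrix2 is only used when all the sub-features the shaders rely on are
        // reported, along with workgroup-scope flexible dimensions for the shapes checked below.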
4973 if (coopmat2_support) {
4974#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
4975 if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
4976 coopmat2_features.cooperativeMatrixFlexibleDimensions &&
4977 coopmat2_features.cooperativeMatrixReductions &&
4978 coopmat2_features.cooperativeMatrixConversions &&
4979 coopmat2_features.cooperativeMatrixPerElementOperations &&
4980 coopmat2_features.cooperativeMatrixTensorAddressing &&
4981 coopmat2_features.cooperativeMatrixBlockLoads &&
4982 vk12_features.bufferDeviceAddress) {
4983
4984 std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
4985 uint32_t count = 0;
4986
4987 PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
4988 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
4989 (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
4990 vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
4991
4992 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
4993
4994 VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
4995 empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
4996 flexible_dimensions.resize(count, empty_prop);
4997
4998 _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
4999
5000 bool found_fp16_128 = false,
5001 found_fp16_256 = false,
5002 found_fp32_128 = false,
5003 found_fp32_256 = false;
5004                // Need fp16*fp16 with both fp16 and fp32 accumulators: workgroup size 128 with
5005                // 32x16x16 granularity and workgroup size 256 with 32x32x16.
5006 for (auto &prop : flexible_dimensions) {
5007 if (prop.saturatingAccumulation == VK_FALSE &&
5008 prop.scope == VK_SCOPE_WORKGROUP_KHR &&
5009 prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5010 prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5011
5012 if (prop.workgroupInvocations == 128 &&
5013 prop.MGranularity <= 32 &&
5014 prop.NGranularity <= 16 &&
5015 prop.KGranularity <= 16) {
5016 if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5017 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5018 found_fp16_128 = true;
5019 }
5020 if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5021 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5022 found_fp32_128 = true;
5023 }
5024 }
5025 if (prop.workgroupInvocations == 256 &&
5026 prop.MGranularity <= 32 &&
5027 prop.NGranularity <= 32 &&
5028 prop.KGranularity <= 16) {
5029 if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5030 prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5031 found_fp16_256 = true;
5032 }
5033 if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5034 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5035 found_fp32_256 = true;
5036 }
5037 }
5038 }
5039 }
5040 if (found_fp16_128 && found_fp16_256 &&
5041 found_fp32_128 && found_fp32_256 &&
5042 coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
5043 device->coopmat2 = true;
5044 }
5045 }
5046#endif
5047 }
5048
5049 if (!vk11_features.storageBuffer16BitAccess) {
5050 std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
5051 throw std::runtime_error("Unsupported device");
5052 }
5053
5054 device_extensions.push_back("VK_KHR_16bit_storage");
5055
5056#ifdef GGML_VULKAN_VALIDATE
5057 device_extensions.push_back("VK_KHR_shader_non_semantic_info");
5058#endif
5059
5060 if (device->fp16) {
5061 device_extensions.push_back("VK_KHR_shader_float16_int8");
5062 }
5063
5064#if defined(VK_KHR_cooperative_matrix)
5065 if (device->coopmat_support) {
5066 // Query supported shapes
5067 std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
5068
5069 PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
5070 (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
5071
5072 uint32_t cm_props_num;
5073
5074 pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
5075
5076 cm_props.resize(cm_props_num);
5077
5078 for (auto& prop : cm_props) {
5079 prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
5080 }
5081
5082 pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
5083
5084 VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
5085
5086 for (auto& prop : cm_props) {
5087 VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
5088
5089 if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
5090 (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
5091 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
5092 ) {
5093 if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
5094 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
5095 // coopmat sizes not set yet
5096 if (device->coopmat_m == 0) {
5097 device->coopmat_acc_f32_support = true;
5098 device->coopmat_m = prop.MSize;
5099 device->coopmat_n = prop.NSize;
5100 device->coopmat_k = prop.KSize;
5101 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
5102 // Only enable if shape is identical
5103 device->coopmat_acc_f32_support = true;
5104 }
5105 if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
5106 device->coopmat_support_16x16x16_f32acc = true;
5107 }
5108 } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
5109 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
5110 // coopmat sizes not set yet
5111 if (device->coopmat_m == 0) {
5112 device->coopmat_acc_f16_support = true;
5113 device->coopmat_m = prop.MSize;
5114 device->coopmat_n = prop.NSize;
5115 device->coopmat_k = prop.KSize;
5116 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
5117 // Only enable if shape is identical
5118 device->coopmat_acc_f16_support = true;
5119 }
5120 if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
5121 device->coopmat_support_16x16x16_f16acc = true;
5122 }
5123 }
5124 } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
5125 (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
5126 (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
5127 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
5128 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
5129 device->coopmat_int_m == 0
5130 ) {
5131 device->coopmat_int_support = true;
5132 device->coopmat_int_m = prop.MSize;
5133 device->coopmat_int_n = prop.NSize;
5134 device->coopmat_int_k = prop.KSize;
5135 }
5136#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
5137 if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
5138 prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
5139 prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5140 prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5141 (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
5142 ) {
5143 // coopmat sizes not set yet
5144 if (device->coopmat_m == 0) {
5145 device->coopmat_bf16_support = true;
5146 device->coopmat_m = prop.MSize;
5147 device->coopmat_n = prop.NSize;
5148 device->coopmat_k = prop.KSize;
5149 } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
5150 // Only enable if shape is identical
5151 device->coopmat_bf16_support = true;
5152 }
5153 }
5154#endif
5155 }
5156
5157 if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
5158 // No suitable matmul mode found
5159 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
5160 device->coopmat_support = false;
5161 }
5162 if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
5163 device->coopmat_bf16_support = false;
5164 }
5165 }
5166
5167 if (device->coopmat_support) {
5168 device_extensions.push_back("VK_KHR_cooperative_matrix");
5169 }
5170#if defined(VK_KHR_shader_bfloat16)
5171 if (device->coopmat_bf16_support) {
5172 device_extensions.push_back("VK_KHR_shader_bfloat16");
5173 }
5174#endif
5175#endif
5176 device->name = GGML_VK_NAME + std::to_string(idx);
5177
5178 device_create_info = {
5179 vk::DeviceCreateFlags(),
5180 device_queue_create_infos,
5181 {},
5182 device_extensions
5183 };
5184 device_create_info.setPNext(&device_features2);
5185 device->device = device->physical_device.createDevice(device_create_info);
5186
5187 // Queues
5188 ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
5189
5190 // Shaders
5191        // Disable matmul tile sizes early if performance is low or they are not supported
5192 for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) {
5193 switch (device->vendor_id) {
5194#ifndef GGML_VULKAN_RUN_TESTS
5195 case VK_VENDOR_ID_AMD:
5196 device->mul_mat_l[i] = device->coopmat_support && device->driver_id != vk::DriverId::eAmdProprietary;
5197 device->mul_mat_m[i] = true;
5198 device->mul_mat_s[i] = true;
5199 device->mul_mat_id_l[i] = false;
5200 device->mul_mat_id_m[i] = true;
5201 device->mul_mat_id_s[i] = true;
5202 break;
5203 case VK_VENDOR_ID_INTEL:
5204 if (!device->coopmat_support || device->architecture != INTEL_XE2) {
5205 device->mul_mat_l[i] = false;
5206 device->mul_mat_id_l[i] = false;
5207 } else {
5208 device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel
5209 device->mul_mat_id_l[i] = true;
5210 }
5211 device->mul_mat_m[i] = true;
5212 device->mul_mat_s[i] = true;
5213 device->mul_mat_id_m[i] = true;
5214 device->mul_mat_id_s[i] = true;
5215 break;
5216 case VK_VENDOR_ID_APPLE:
5217 device->mul_mat_l[i] = false;
5218 device->mul_mat_m[i] = true;
5219 device->mul_mat_s[i] = false;
5220 device->mul_mat_id_l[i] = false;
5221 device->mul_mat_id_m[i] = true;
5222 device->mul_mat_id_s[i] = false;
5223 break;
5224#endif
5225 default:
5226 device->mul_mat_l[i] = true;
5227 device->mul_mat_m[i] = true;
5228 device->mul_mat_s[i] = true;
5229 device->mul_mat_id_l[i] = true;
5230 device->mul_mat_id_m[i] = true;
5231 device->mul_mat_id_s[i] = true;
5232 break;
5233 }
5234 }
5235
5236
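        // One descriptor set layout with MAX_PARAMETER_COUNT storage-buffer bindings.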
5237 std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
5238 std::vector<vk::DescriptorBindingFlags> dsl_binding_flags;
5239 for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) {
5240 dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute});
5241 dsl_binding_flags.push_back({});
5242 }
5243
5244 vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags };
5245
5246 vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info(
5247 {},
5248 dsl_binding);
5249 descriptor_set_layout_create_info.setPNext(&dslbfci);
5250 device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info);
5251
5252 ggml_vk_load_shaders(device);
5253
5254 if (!device->single_queue) {
5255 const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
5256 ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
5257 } else {
5258 // TODO: Use pointer or reference to avoid copy
5259 device->transfer_queue.copyFrom(device->compute_queue);
5260 device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
5261 }
5262
5263 device->buffer_type = {
5264 /* .iface = */ ggml_backend_vk_buffer_type_interface,
5265 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
5266 /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
5267 };
5268
5269 device->fence = device->device.createFence({});
5270
5271 device->idx = idx;
5272
5273 device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
5274
5275 device->add_rms_fusion = !device->disable_fusion &&
5276 device->subgroup_arithmetic &&
5277 device->vendor_id != VK_VENDOR_ID_INTEL;
5278 device->partials_binding_alignment =
5279 std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment);
5280
5281 device->mmvq_mode = 0;
5282 if (getenv("GGML_VK_DISABLE_MMVQ")) {
5283 device->mmvq_mode = -1;
5284 } else if (getenv("GGML_VK_FORCE_MMVQ")) {
5285 device->mmvq_mode = 1;
5286 }
5287
5288 return device;
5289 }
5290
5291 return vk_instance.devices[idx];
5292}
5293
5294static void ggml_vk_print_gpu_info(size_t idx) {
5295 GGML_ASSERT(idx < vk_instance.device_indices.size());
5296 size_t dev_num = vk_instance.device_indices[idx];
5297 VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")");
5298 GGML_ASSERT(vk_instance_initialized);
5299
5300 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
5301
5302 if (dev_num >= devices.size()) {
5303 std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl;
5304 throw std::runtime_error("Device not found");
5305 }
5306
5307 vk::PhysicalDevice physical_device = devices[dev_num];
5308 std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
5309
5310 bool fp16_storage = false;
5311 bool fp16_compute = false;
5312 bool coopmat_support = false;
5313 bool coopmat2_support = false;
5314 bool integer_dot_product = false;
5315 bool bfloat16_support = false;
5316
5317 for (auto properties : ext_props) {
5318 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
5319 fp16_storage = true;
5320 } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
5321 fp16_compute = true;
5322#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5323 } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
5324 !getenv("GGML_VK_DISABLE_COOPMAT")) {
5325 coopmat_support = true;
5326#endif
5327#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
5328 } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
5329 !getenv("GGML_VK_DISABLE_COOPMAT2")) {
5330 coopmat2_support = true;
5331#endif
5332#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
5333 } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
5334 !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
5335 integer_dot_product = true;
5336#endif
5337#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
5338 } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
5339 !getenv("GGML_VK_DISABLE_BFLOAT16")) {
5340 bfloat16_support = true;
5341#endif
5342 }
5343 }
5344
5345 const vk_device_architecture device_architecture = get_device_architecture(physical_device);
5346
5347 const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
5348 bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
5349
5350 bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
5351
5352 vk::PhysicalDeviceProperties2 props2;
5353 vk::PhysicalDeviceMaintenance3Properties props3;
5354 vk::PhysicalDeviceSubgroupProperties subgroup_props;
5355 vk::PhysicalDeviceDriverProperties driver_props;
5356 vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
5357 props2.pNext = &props3;
5358 props3.pNext = &subgroup_props;
5359 subgroup_props.pNext = &driver_props;
5360
5361 // Pointer to the last chain element
5362 VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
5363
5364 if (integer_dot_product) {
5365 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
5366 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
5367 }
5368
5369 physical_device.getProperties2(&props2);
5370
5371 VkPhysicalDeviceFeatures2 device_features2;
5372 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
5373 device_features2.pNext = nullptr;
5374
5375 VkPhysicalDeviceVulkan11Features vk11_features;
5376 vk11_features.pNext = nullptr;
5377 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
5378 device_features2.pNext = &vk11_features;
5379
5380 VkPhysicalDeviceVulkan12Features vk12_features;
5381 vk12_features.pNext = nullptr;
5382 vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
5383 vk11_features.pNext = &vk12_features;
5384
5385 // Pointer to the last chain element
5386 last_struct = (VkBaseOutStructure *)&vk12_features;
5387
5388#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5389 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
5390 coopmat_features.pNext = nullptr;
5391 coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
5392 coopmat_features.cooperativeMatrix = VK_FALSE;
5393
5394 if (coopmat_support) {
5395 last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
5396 last_struct = (VkBaseOutStructure *)&coopmat_features;
5397 }
5398#endif
5399
5400 VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
5401 shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
5402 if (integer_dot_product) {
5403 last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
5404 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
5405 }
5406
5407#if defined(VK_KHR_shader_bfloat16)
5408 VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
5409 bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
5410 if (bfloat16_support) {
5411 last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
5412 last_struct = (VkBaseOutStructure *)&bfloat16_features;
5413 }
5414#endif
5415
5416 vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
5417
5418 fp16 = fp16 && vk12_features.shaderFloat16;
5419
5420#if defined(VK_KHR_shader_bfloat16)
5421 bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
5422#else
5423 bool bf16 = false;
5424#endif
5425
5426 uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
5427 const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
5428 const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
5429
5430 integer_dot_product = integer_dot_product
5431 && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
5432 && shader_integer_dot_product_features.shaderIntegerDotProduct;
5433
5434 coopmat_support = coopmat_support
5435#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5436 && coopmat_features.cooperativeMatrix
5437#endif
5438 && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
5439
5440 std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
5441
5442 std::string device_name = props2.properties.deviceName.data();
5443    GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %u | int dot: %d | matrix cores: %s\n",
5444 idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
5445 props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
5446
5447 if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
5448 GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
5449 }
5450}
5451
5452static bool ggml_vk_instance_layer_settings_available();
5453static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions);
5454static bool ggml_vk_instance_debug_utils_ext_available(const std::vector<vk::ExtensionProperties> & instance_extensions);
5455static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev);
5456
5457static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
5458DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
5459 return ggml_vk_default_dispatcher_instance;
5460}
5461
5462static void ggml_vk_instance_init() {
5463 if (vk_instance_initialized) {
5464 return;
5465 }
5466 VK_LOG_DEBUG("ggml_vk_instance_init()");
5467
5468 // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
5469 ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr);
5470
5471 uint32_t api_version = vk::enumerateInstanceVersion();
5472
5473 if (api_version < VK_API_VERSION_1_2) {
5474 std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
5475 throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required");
5476 }
5477
5478 vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
5479
5480 const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
5481 const bool layer_settings = ggml_vk_instance_layer_settings_available();
5482#ifdef __APPLE__
5483 const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions);
5484#endif
5485 const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr;
5486 std::vector<const char*> layers;
5487
5488 if (layer_settings) {
5489 layers.push_back("VK_LAYER_KHRONOS_validation");
5490 }
5491 std::vector<const char*> extensions;
5492 if (layer_settings) {
5493 extensions.push_back("VK_EXT_layer_settings");
5494 }
5495#ifdef __APPLE__
5496 if (portability_enumeration_ext) {
5497 extensions.push_back("VK_KHR_portability_enumeration");
5498 }
5499#endif
5500 if (debug_utils_ext) {
5501 extensions.push_back("VK_EXT_debug_utils");
5502 }
5503 VkBool32 enable_best_practice = layer_settings;
5504 std::vector<vk::LayerSettingEXT> settings = {
5505 {
5506 "VK_LAYER_KHRONOS_validation",
5507 "validate_best_practices",
5508 vk::LayerSettingTypeEXT::eBool32,
5509 1,
5510 &enable_best_practice
5511 },
5512 };
5513 vk::LayerSettingsCreateInfoEXT layer_setting_info(settings);
5514 vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions, &layer_setting_info);
5515#ifdef __APPLE__
5516 if (portability_enumeration_ext) {
5517 instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR;
5518 }
5519#endif
5520
5521 vk_instance.instance = vk::createInstance(instance_create_info);
5522 vk_instance_initialized = true;
5523
5524 if (debug_utils_ext) {
5525 vk_instance.debug_utils_support = true;
5526 vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT");
5527 vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT");
5528 vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT");
5529 vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
5530 vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
5531 vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
5532 }
5533
5534 vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
5535 vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
5536 vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
5537 vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
5538 const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
5539
5540 if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
5541 vk_perf_logger_frequency = std::stoul(GGML_VK_PERF_LOGGER_FREQUENCY);
5542 }
5543
5544 // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers-
5545 VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance);
5546
5547 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
5548
5549 // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
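// The value is a comma- or space-separated list of physical-device indices; e.g. a
// hypothetical GGML_VK_VISIBLE_DEVICES=0,2 restricts ggml to devices 0 and 2 in enumeration
// order. Out-of-range indices are rejected with an error below.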
5550 char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
5551 if (devices_env != nullptr) {
5552 size_t num_available_devices = devices.size();
5553
5554 std::string devices(devices_env);
5555 std::replace(devices.begin(), devices.end(), ',', ' ');
5556
5557 std::stringstream ss(devices);
5558 size_t tmp;
5559 while (ss >> tmp) {
5560 if(tmp >= num_available_devices) {
5561 std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl;
5562 throw std::runtime_error("Invalid Vulkan device index");
5563 }
5564 vk_instance.device_indices.push_back(tmp);
5565 }
5566 } else {
5567 // If no Vulkan devices are found, return early
5568 if (devices.empty()) {
5569 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
5570 return;
5571 }
5572
5573 // Default to using all dedicated GPUs
5574 for (size_t i = 0; i < devices.size(); i++) {
5575 vk::PhysicalDeviceProperties2 new_props;
5576 vk::PhysicalDeviceDriverProperties new_driver;
5577 vk::PhysicalDeviceIDProperties new_id;
5578 new_props.pNext = &new_driver;
5579 new_driver.pNext = &new_id;
5580 devices[i].getProperties2(&new_props);
5581
5582 if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
5583 // Check if there are two physical devices corresponding to the same GPU
5584 // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
5585 // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
5586 // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
5587 // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip deduplication when
5588 // both the old and the new driver are MoltenVK.
5589 auto old_device = std::find_if(
5590 vk_instance.device_indices.begin(),
5591 vk_instance.device_indices.end(),
5592 [&devices, &new_id, &new_driver](const size_t k){
5593 vk::PhysicalDeviceProperties2 old_props;
5594 vk::PhysicalDeviceDriverProperties old_driver;
5595 vk::PhysicalDeviceIDProperties old_id;
5596 old_props.pNext = &old_driver;
5597 old_driver.pNext = &old_id;
5598 devices[k].getProperties2(&old_props);
5599
5600 bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
5601 same_uuid = same_uuid || (
5602 old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
5603 std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
5604 );
5605 bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
5606
5607 return same_uuid && !both_molten_vk;
5608 }
5609 );
5610 if (old_device == vk_instance.device_indices.end()) {
5611 vk_instance.device_indices.push_back(i);
5612 } else {
5613 // There can be two physical devices corresponding to the same GPU if two different drivers are installed.
5614 // This can cause errors when splitting layers across the devices, so only one of them is kept.
5615 VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID");
5616
5617 vk::PhysicalDeviceProperties2 old_props;
5618 vk::PhysicalDeviceDriverProperties old_driver;
5619 old_props.pNext = &old_driver;
5620 devices[*old_device].getProperties2(&old_props);
5621
5622 std::map<vk::DriverId, int> driver_priorities {};
5623 int old_priority = std::numeric_limits<int>::max();
5624 int new_priority = std::numeric_limits<int>::max();
5625
5626 // See https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver IDs
5627 // Smaller number -> higher priority
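// For example, if the same AMD GPU is exposed by both RADV (priority 1) and the proprietary
// driver (priority 3), the RADV entry is kept and the other one is dropped.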
5628 switch (old_props.properties.vendorID) {
5629 case VK_VENDOR_ID_AMD:
5630 driver_priorities[vk::DriverId::eMesaRadv] = 1;
5631 driver_priorities[vk::DriverId::eAmdOpenSource] = 2;
5632 driver_priorities[vk::DriverId::eAmdProprietary] = 3;
5633 break;
5634 case VK_VENDOR_ID_INTEL:
5635 driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1;
5636 driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2;
5637 break;
5638 case VK_VENDOR_ID_NVIDIA:
5639 driver_priorities[vk::DriverId::eNvidiaProprietary] = 1;
5640#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235
5641 driver_priorities[vk::DriverId::eMesaNvk] = 2;
5642#endif
5643 break;
5644 }
5645 driver_priorities[vk::DriverId::eMesaDozen] = 100;
5646
5647 if (driver_priorities.count(old_driver.driverID)) {
5648 old_priority = driver_priorities[old_driver.driverID];
5649 }
5650 if (driver_priorities.count(new_driver.driverID)) {
5651 new_priority = driver_priorities[new_driver.driverID];
5652 }
5653
5654 if (new_priority < old_priority) {
5655 auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device);
5656 vk_instance.device_indices.erase(r, vk_instance.device_indices.end());
5657 vk_instance.device_indices.push_back(i);
5658
5659 VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << old_driver.driverName);
5660 }
5661 else {
5662 VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName);
5663 }
5664 }
5665 }
5666 }
5667
5668 // If no GPUs found, fall back to the first non-CPU device.
5669 // If only CPU devices are available, return without devices.
5670 if (vk_instance.device_indices.empty()) {
5671 for (size_t i = 0; i < devices.size(); i++) {
5672 if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) {
5673 vk_instance.device_indices.push_back(i);
5674 break;
5675 }
5676 }
5677 }
5678
5679 if (vk_instance.device_indices.empty()) {
5680 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
5681 return;
5682 }
5683 }
5684 GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
5685
5686 for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
5687 vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
5688 std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
5689
5690 bool membudget_supported = false;
5691 for (const auto & ext : extensionprops) {
5692 if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) {
5693 membudget_supported = true;
5694 break;
5695 }
5696 }
5697
5698 vk_instance.device_supports_membudget.push_back(membudget_supported);
5699
5700 ggml_vk_print_gpu_info(i);
5701 }
5702}
5703
5704static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
5705 VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")");
5706 ggml_vk_instance_init();
5707 GGML_ASSERT(idx < vk_instance.device_indices.size());
5708
5709 ctx->name = GGML_VK_NAME + std::to_string(idx);
5710
5711 ctx->device = ggml_vk_get_device(idx);
5712
5713 ctx->semaphore_idx = 0;
5714 ctx->event_idx = 0;
5715
5716 ctx->prealloc_size_x = 0;
5717 ctx->prealloc_size_y = 0;
5718 ctx->prealloc_size_split_k = 0;
5719 // Fixed size of 1KB, for deterministic behavior
5720 ctx->prealloc_size_add_rms_partials = 1024;
5721
5722 ctx->fence = ctx->device->device.createFence({});
5723 ctx->almost_ready_fence = ctx->device->device.createFence({});
5724
5725 ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
5726
5727 if (vk_perf_logger_enabled) {
5728 ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
5729 }
5730
5731#ifdef GGML_VULKAN_CHECK_RESULTS
5732 const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
5733 vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks));
5734 const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR");
5735 vk_output_tensor = (output_tensor == NULL ? 0 : atoi(output_tensor));
5736#endif
5737}
5738
5739static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) {
5740 VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
5741 switch (type) {
5742 case GGML_TYPE_F32:
5743 case GGML_TYPE_Q4_0:
5744 case GGML_TYPE_Q4_1:
5745 case GGML_TYPE_Q5_0:
5746 case GGML_TYPE_Q5_1:
5747 case GGML_TYPE_Q8_0:
5748 case GGML_TYPE_Q2_K:
5749 case GGML_TYPE_Q3_K:
5750 case GGML_TYPE_Q4_K:
5751 case GGML_TYPE_Q5_K:
5752 case GGML_TYPE_Q6_K:
5753 case GGML_TYPE_IQ1_S:
5754 case GGML_TYPE_IQ1_M:
5755 case GGML_TYPE_IQ2_XXS:
5756 case GGML_TYPE_IQ2_XS:
5757 case GGML_TYPE_IQ2_S:
5758 case GGML_TYPE_IQ3_XXS:
5759 case GGML_TYPE_IQ3_S:
5760 case GGML_TYPE_IQ4_XS:
5761 case GGML_TYPE_IQ4_NL:
5762 case GGML_TYPE_MXFP4:
5763 break;
5764 default:
5765 return nullptr;
5766 }
5767
5768 return ctx->device->pipeline_dequant[type];
5769}
5770
5771static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
5772 VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")");
5773 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
5774 return ctx->device->pipeline_matmul_f32;
5775 }
5776 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
5777 return ctx->device->pipeline_matmul_f32_f16;
5778 }
5779 if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
5780 return ctx->device->pipeline_matmul_bf16;
5781 }
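// Prefer fp16 accumulation only for default precision on devices with fp16 support, and not
// when KHR coopmat is available but lacks fp16-accumulator support.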
5782 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
5783 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5784 return ctx->device->pipeline_matmul_f16_f32.f16acc;
5785 }
5786 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5787 return ctx->device->pipeline_matmul_f16.f16acc;
5788 }
5789 } else {
5790 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5791 return ctx->device->pipeline_matmul_f16_f32.f32acc;
5792 }
5793 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5794 return ctx->device->pipeline_matmul_f16.f32acc;
5795 }
5796 }
5797
5798 // MMQ
5799 if (src1_type == GGML_TYPE_Q8_1) {
5800 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc;
5801
5802 if (pipelines->is_empty()) {
5803 return nullptr;
5804 }
5805
5806 return pipelines;
5807 }
5808
5809 if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
5810 return nullptr;
5811 }
5812
5813 switch (src0_type) {
5814 case GGML_TYPE_Q4_0:
5815 case GGML_TYPE_Q4_1:
5816 case GGML_TYPE_Q5_0:
5817 case GGML_TYPE_Q5_1:
5818 case GGML_TYPE_Q8_0:
5819 case GGML_TYPE_Q2_K:
5820 case GGML_TYPE_Q3_K:
5821 case GGML_TYPE_Q4_K:
5822 case GGML_TYPE_Q5_K:
5823 case GGML_TYPE_Q6_K:
5824 case GGML_TYPE_IQ1_S:
5825 case GGML_TYPE_IQ1_M:
5826 case GGML_TYPE_IQ2_XXS:
5827 case GGML_TYPE_IQ2_XS:
5828 case GGML_TYPE_IQ2_S:
5829 case GGML_TYPE_IQ3_XXS:
5830 case GGML_TYPE_IQ3_S:
5831 case GGML_TYPE_IQ4_XS:
5832 case GGML_TYPE_IQ4_NL:
5833 case GGML_TYPE_MXFP4:
5834 break;
5835 default:
5836 return nullptr;
5837 }
5838
5839 if (ctx->device->coopmat2) {
5840 assert(src1_type == GGML_TYPE_F16);
5841 return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc;
5842 }
5843 if (ctx->device->coopmat_support) {
5844 return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
5845 }
5846 return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
5847}
5848
5849static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) {
5850 VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
5851 GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1);
5852 GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
5853
5854 if (b_type == GGML_TYPE_Q8_1) {
5855 switch (a_type) {
5856 case GGML_TYPE_Q4_0:
5857 case GGML_TYPE_Q4_1:
5858 case GGML_TYPE_Q5_0:
5859 case GGML_TYPE_Q5_1:
5860 case GGML_TYPE_Q8_0:
5861 case GGML_TYPE_MXFP4:
5862 case GGML_TYPE_Q2_K:
5863 case GGML_TYPE_Q3_K:
5864 case GGML_TYPE_Q4_K:
5865 case GGML_TYPE_Q5_K:
5866 case GGML_TYPE_Q6_K:
5867 case GGML_TYPE_IQ1_S:
5868 case GGML_TYPE_IQ1_M:
5869 break;
5870 default:
5871 return nullptr;
5872 }
5873 }
5874
5875 switch (a_type) {
5876 case GGML_TYPE_F32:
5877 case GGML_TYPE_F16:
5878 case GGML_TYPE_BF16:
5879 case GGML_TYPE_Q4_0:
5880 case GGML_TYPE_Q4_1:
5881 case GGML_TYPE_Q5_0:
5882 case GGML_TYPE_Q5_1:
5883 case GGML_TYPE_Q8_0:
5884 case GGML_TYPE_Q2_K:
5885 case GGML_TYPE_Q3_K:
5886 case GGML_TYPE_Q4_K:
5887 case GGML_TYPE_Q5_K:
5888 case GGML_TYPE_Q6_K:
5889 case GGML_TYPE_IQ1_S:
5890 case GGML_TYPE_IQ1_M:
5891 case GGML_TYPE_IQ2_XXS:
5892 case GGML_TYPE_IQ2_XS:
5893 case GGML_TYPE_IQ2_S:
5894 case GGML_TYPE_IQ3_XXS:
5895 case GGML_TYPE_IQ3_S:
5896 case GGML_TYPE_IQ4_XS:
5897 case GGML_TYPE_IQ4_NL:
5898 case GGML_TYPE_MXFP4:
5899 break;
5900 default:
5901 return nullptr;
5902 }
5903
5904 // heuristic to choose workgroup size
5905 uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
5906 if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
5907 // Prefer larger workgroups when M is small, to spread the work out more
5908 // and keep more SMs busy.
5909 // q6_k seems to prefer small workgroup size even for "medium" values of M.
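// For example (hypothetical shapes): on an NVIDIA GPU that is not pre-Turing, a q4_0 matrix
// with m == 2048 and k == 4096 selects DMMV_WG_SIZE_LARGE, while q6_k with m == 8192 keeps
// the default subgroup-sized workgroup.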
5910 if (a_type == GGML_TYPE_Q6_K) {
5911 if (m < 4096 && k >= 1024) {
5912 dmmv_wg = DMMV_WG_SIZE_LARGE;
5913 }
5914 } else {
5915 if (m <= 8192 && k >= 1024) {
5916 dmmv_wg = DMMV_WG_SIZE_LARGE;
5917 }
5918 }
5919 }
5920
5921 if (b_type == GGML_TYPE_Q8_1) {
5922 if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
5923 dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
5924 }
5925 return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1];
5926 }
5927
5928 return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1];
5929}
5930
5931static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
5932 VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
5933 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
5934 return ctx->device->pipeline_matmul_id_f32;
5935 }
5936 if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
5937 return ctx->device->pipeline_matmul_id_bf16;
5938 }
5939 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
5940 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5941 return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
5942 }
5943 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5944 return ctx->device->pipeline_matmul_id_f16.f16acc;
5945 }
5946 } else {
5947 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
5948 return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
5949 }
5950 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
5951 return ctx->device->pipeline_matmul_id_f16.f32acc;
5952 }
5953 }
5954
5955 // MMQ
5956 if (src1_type == GGML_TYPE_Q8_1) {
5957 vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_id_q8_1[src0_type].f32acc;
5958
5959 if (pipelines->is_empty()) {
5960 return nullptr;
5961 }
5962
5963 return pipelines;
5964 }
5965
5966 GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
5967
5968 switch (src0_type) {
5969 case GGML_TYPE_Q4_0:
5970 case GGML_TYPE_Q4_1:
5971 case GGML_TYPE_Q5_0:
5972 case GGML_TYPE_Q5_1:
5973 case GGML_TYPE_Q8_0:
5974 case GGML_TYPE_Q2_K:
5975 case GGML_TYPE_Q3_K:
5976 case GGML_TYPE_Q4_K:
5977 case GGML_TYPE_Q5_K:
5978 case GGML_TYPE_Q6_K:
5979 case GGML_TYPE_IQ1_S:
5980 case GGML_TYPE_IQ1_M:
5981 case GGML_TYPE_IQ2_XXS:
5982 case GGML_TYPE_IQ2_XS:
5983 case GGML_TYPE_IQ2_S:
5984 case GGML_TYPE_IQ3_XXS:
5985 case GGML_TYPE_IQ3_S:
5986 case GGML_TYPE_IQ4_XS:
5987 case GGML_TYPE_IQ4_NL:
5988 case GGML_TYPE_MXFP4:
5989 break;
5990 default:
5991 return nullptr;
5992 }
5993
5994 vk_matmul_pipeline2& mmp = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
5995 // XXX TODO 'prec' is not actually allowed in mul_mat_id.
5996 bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/;
5997 bool support_fp16acc = !mmp.f16acc->is_empty();
5998 bool support_fp32acc = !mmp.f32acc->is_empty();
5999
6000 if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) {
6001 return mmp.f16acc;
6002 } else {
6003 GGML_ASSERT(support_fp32acc);
6004 return mmp.f32acc;
6005 }
6006}
6007
6008static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
6009 VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
6010 GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
6011
6012 if (b_type == GGML_TYPE_Q8_1) {
6013 switch (a_type) {
6014 case GGML_TYPE_Q4_0:
6015 case GGML_TYPE_Q4_1:
6016 case GGML_TYPE_Q5_0:
6017 case GGML_TYPE_Q5_1:
6018 case GGML_TYPE_Q8_0:
6019 case GGML_TYPE_MXFP4:
6020 case GGML_TYPE_Q2_K:
6021 case GGML_TYPE_Q3_K:
6022 case GGML_TYPE_Q4_K:
6023 case GGML_TYPE_Q5_K:
6024 case GGML_TYPE_Q6_K:
6025 case GGML_TYPE_IQ1_S:
6026 case GGML_TYPE_IQ1_M:
6027 break;
6028 default:
6029 return nullptr;
6030 }
6031 }
6032
6033 switch (a_type) {
6034 case GGML_TYPE_F32:
6035 case GGML_TYPE_F16:
6036 case GGML_TYPE_BF16:
6037 case GGML_TYPE_Q4_0:
6038 case GGML_TYPE_Q4_1:
6039 case GGML_TYPE_Q5_0:
6040 case GGML_TYPE_Q5_1:
6041 case GGML_TYPE_Q8_0:
6042 case GGML_TYPE_Q2_K:
6043 case GGML_TYPE_Q3_K:
6044 case GGML_TYPE_Q4_K:
6045 case GGML_TYPE_Q5_K:
6046 case GGML_TYPE_Q6_K:
6047 case GGML_TYPE_IQ1_S:
6048 case GGML_TYPE_IQ1_M:
6049 case GGML_TYPE_IQ2_XXS:
6050 case GGML_TYPE_IQ2_XS:
6051 case GGML_TYPE_IQ2_S:
6052 case GGML_TYPE_IQ3_XXS:
6053 case GGML_TYPE_IQ3_S:
6054 case GGML_TYPE_IQ4_XS:
6055 case GGML_TYPE_IQ4_NL:
6056 case GGML_TYPE_MXFP4:
6057 break;
6058 default:
6059 return nullptr;
6060 }
6061
6062 // heuristic to choose workgroup size
6063 uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
6064 if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
6065 // Prefer larger workgroups when M is small, to spread the work out more
6066 // and keep more SMs busy.
6067 // q6_k seems to prefer small workgroup size even for "medium" values of M.
6068 if (a_type == GGML_TYPE_Q6_K) {
6069 if (m < 4096 && k >= 1024) {
6070 dmmv_wg = DMMV_WG_SIZE_LARGE;
6071 }
6072 } else {
6073 if (m <= 8192 && k >= 1024) {
6074 dmmv_wg = DMMV_WG_SIZE_LARGE;
6075 }
6076 }
6077 }
6078
6079 if (b_type == GGML_TYPE_Q8_1) {
6080 if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
6081 dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
6082 }
6083 return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
6084 }
6085
6086 return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
6087}
6088
6089static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
6090 VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
6091 vk_buffer buf = ggml_vk_create_buffer(device, size,
6092 {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
6093 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
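// The two flag sets are the preferred and fallback memory properties: ideally host-visible,
// coherent and cached; if that combination is unavailable, the allocation is expected to
// fall back to plain host-visible + coherent memory.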
6094
6095 if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
6096 fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
6097 size/1024.0/1024.0);
6098 device->device.freeMemory(buf->device_memory);
6099 device->device.destroyBuffer(buf->buffer);
6100 return nullptr;
6101 }
6102
6103 std::lock_guard<std::recursive_mutex> guard(device->mutex);
6104 device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
6105
6106 return buf->ptr;
6107}
6108
6109static void ggml_vk_host_free(vk_device& device, void* ptr) {
6110 if (ptr == nullptr) {
6111 return;
6112 }
6113 VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
6114 std::lock_guard<std::recursive_mutex> guard(device->mutex);
6115
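// pinned_memory holds (base pointer, size, vk_buffer) tuples; find the allocation whose
// address range contains ptr.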
6116 vk_buffer buf;
6117 size_t index;
6118 for (size_t i = 0; i < device->pinned_memory.size(); i++) {
6119 const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
6120 const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
6121 if (ptr >= addr && ptr < endr) {
6122 buf = std::get<2>(device->pinned_memory[i]);
6123 index = i;
6124 break;
6125 }
6126 }
6127 if (buf == nullptr) {
6128 fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n");
6129 return;
6130 }
6131
6132 ggml_vk_destroy_buffer(buf);
6133
6134 device->pinned_memory.erase(device->pinned_memory.begin() + index);
6135}
6136
6137static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
6138 std::lock_guard<std::recursive_mutex> guard(device->mutex);
6139 buf = nullptr;
6140 buf_offset = 0;
6141 for (size_t i = 0; i < device->pinned_memory.size(); i++) {
6142 const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]);
6143 const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]);
6144 if (ptr >= addr && ptr < endr) {
6145 buf = std::get<2>(device->pinned_memory[i]);
6146 buf_offset = ((const uint8_t *)ptr) - addr;
6147 break;
6148 }
6149 }
6150}
6151
6152static vk_subbuffer ggml_vk_tensor_subbuffer(
6153 const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
6154
6155 vk_buffer buffer = nullptr;
6156 size_t offset = 0;
6157 if (ctx->device->uma) {
6158 ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
6159 }
6160 if (!buffer) {
6161 auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
6162 buffer = buf_ctx->dev_buffer;
6163 offset = vk_tensor_offset(tensor) + tensor->view_offs;
6164 }
6165 GGML_ASSERT(buffer != nullptr);
6166
6167 size_t size = ggml_nbytes(tensor);
6168
6169 size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
6170 // The shader must support misaligned offsets when indexing into the buffer
6171 GGML_ASSERT(allow_misalign || misalign_bytes == 0);
6172 offset &= ~misalign_bytes;
6173 size += misalign_bytes;
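// Example (hypothetical values): with minStorageBufferOffsetAlignment == 64 and offset == 100,
// misalign_bytes == 36, so the subbuffer starts at offset 64 and covers 36 extra leading bytes
// that the shader is expected to skip over.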
6174
6175 return vk_subbuffer{buffer, offset, size};
6176}
6177
6178static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
6179 vk_submission s;
6180 s.buffer = ggml_vk_create_cmd_buffer(device, p);
6181 if (one_time) {
6182 s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
6183 } else {
6184 s.buffer.begin({ vk::CommandBufferUsageFlags{} });
6185 }
6186
6187 return s;
6188}
6189
6190template <typename T> size_t push_constant_size(const T &t) {
6191 static_assert(std::is_class<T>::value, "T must be a struct/class");
6192 GGML_UNUSED(t);
6193 return sizeof(T);
6194}
6195template <typename T> size_t push_constant_size(const std::vector<T> &t) {
6196 GGML_UNUSED(t);
6197 return sizeof(T) * t.size();
6198}
6199template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) {
6200 GGML_UNUSED(t);
6201 return sizeof(T) * N;
6202}
6203
6204template <typename T> const T *push_constant_data(const T &t) {
6205 static_assert(std::is_class<T>::value, "T must be a struct/class");
6206 return &t;
6207}
6208template <typename T> const T *push_constant_data(const std::vector<T> &t) {
6209 return t.data();
6210}
6211template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) {
6212 return t.data();
6213}
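// These overloads let ggml_vk_dispatch_pipeline below accept a plain struct, a std::vector or
// a std::array of push constants and derive the byte size and data pointer uniformly, e.g. a
// std::array<uint32_t, 2> yields a size of 8 bytes.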
6214
6215template <typename T>
6216static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) {
6217 const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
6218 const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
6219 const uint32_t wg2 = CEIL_DIV(elements[2], pipeline->wg_denoms[2]);
6220 VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {";
6221 for (auto& buffer : descriptor_buffer_infos) {
6222 std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), ";
6223 }
6224 std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
6225 GGML_ASSERT(wg0 <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
6226 wg1 <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
6227 wg2 <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
6228 GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
6229 GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
6230 GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
6231 GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants));
6232
6233 vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
6234 vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
6235 ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
6236
6237 subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
6238 subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
6239 subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
6240 pipeline->layout,
6241 0,
6242 { descriptor_set },
6243 {});
6244 subctx->s->buffer.dispatch(wg0, wg1, wg2);
6245}
6246
6247static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
6248 s.buffer.end();
6249
6250 s.wait_semaphores = std::move(wait_semaphores);
6251 s.signal_semaphores = std::move(signal_semaphores);
6252}
6253
6254static void ggml_vk_ctx_end(vk_context& ctx) {
6255 VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
6256 if (ctx->s == nullptr) {
6257 return;
6258 }
6259
6260 ctx->s->buffer.end();
6261 ctx->s = nullptr;
6262}
6263
6264static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
6265 VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")");
6266 if (subctx->s != nullptr) {
6267 ggml_vk_ctx_end(subctx);
6268 }
6269
6270 subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) });
6271 subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
6272}
6273
6274static size_t ggml_vk_align_size(size_t width, size_t align) {
6275 VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
6276 return CEIL_DIV(width, align) * align;
6277}
6278
6279static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector<vk_staging_memcpy>* memcpys = nullptr) {
6280 if (memcpys == nullptr) {
6281 memcpy(dst, src, size);
6282 } else {
6283 memcpys->emplace_back(dst, src, size);
6284 }
6285}
6286
6287static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
6288 if (memsets == nullptr) {
6289 memset(dst, val, size);
6290 } else {
6291 memsets->emplace_back(dst, val, size);
6292 }
6293}
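// When a list is provided, the copy/clear is only recorded and executed later by the caller
// (e.g. flushed around command-buffer submission in the synchronous buffer read/write paths
// below); with no list it happens immediately.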
6294
6295static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
6296 if (device->sync_staging == nullptr || device->sync_staging->size < size) {
6297 VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
6298 ggml_vk_destroy_buffer(device->sync_staging);
6299 device->sync_staging = ggml_vk_create_buffer_check(device, size,
6300 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
6301 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
6302 }
6303}
6304
6305static void ggml_vk_ensure_sync_staging_buffer(ggml_backend_vk_context * ctx, size_t size) {
6306 if (ctx->sync_staging == nullptr || ctx->sync_staging->size < size) {
6307 VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
6308 ggml_vk_destroy_buffer(ctx->sync_staging);
6309 ctx->sync_staging = ggml_vk_create_buffer_check(ctx->device, size,
6310 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
6311 vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
6312 }
6313}
6314
6315static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) {
6316 VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")");
6317 GGML_ASSERT(!ggml_is_contiguous(tensor));
6318 // Buffer is already mapped
6319 if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
6320 std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl;
6321 GGML_ABORT("fatal error");
6322 }
6323 // Check if src is pinned memory
6324 vk_buffer buf = nullptr;
6325 size_t buf_offset = 0;
6326 ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
6327
6328 const uint64_t ne0 = tensor->ne[0];
6329 const uint64_t ne1 = tensor->ne[1];
6330 const uint64_t ne2 = tensor->ne[2];
6331 const uint64_t ne3 = tensor->ne[3];
6332 const uint64_t nb0 = tensor->nb[0];
6333 const uint64_t nb1 = tensor->nb[1];
6334 const uint64_t nb2 = tensor->nb[2];
6335 const uint64_t nb3 = tensor->nb[3];
6336 const ggml_type type = tensor->type;
6337 const uint64_t ts = ggml_type_size(type);
6338 const uint64_t bs = ggml_blck_size(type);
6339
6340 const uint64_t dstnb0 = ts;
6341 const uint64_t dstnb1 = dstnb0*(ne0/bs);
6342 const uint64_t dstnb2 = dstnb1*ne1;
6343 const uint64_t dstnb3 = dstnb2*ne2;
6344
6345 const uint64_t ne = ggml_nelements(tensor);
6346
6347 if (buf != nullptr) {
6348 // Memory is pinned, use as staging buffer
6349 std::vector<vk::BufferCopy> slices;
6350
6351 for (uint64_t i3 = 0; i3 < ne3; i3++) {
6352 for (uint64_t i2 = 0; i2 < ne2; i2++) {
6353 // Find longest contiguous slice
6354 if (ne1*nb1 == dstnb2) {
6355 slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 });
6356 } else {
6357 for (uint64_t i1 = 0; i1 < ne1; i1++) {
6358 if (ne0*nb0/bs == dstnb1) {
6359 slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 });
6360 } else {
6361 const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
6362 const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
6363 for (uint64_t i0 = 0; i0 < ne0; i0++) {
6364 slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
6365 }
6366 }
6367 }
6368 }
6369 }
6370 }
6371
6372 ggml_vk_sync_buffers(ctx, subctx);
6373 subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6374 return;
6375 }
6376
6377 if (!sync_staging) {
6378 GGML_ABORT("Asynchronous write to non-pinned memory not supported");
6379 }
6380
6381 // Staging buffer required
6382 vk_buffer& staging = ctx->device->sync_staging;
6383 const uint64_t copy_size = ts*ne/bs;
6384 ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size);
6385 VkBufferCopy buf_copy{ 0, offset, copy_size };
6386
6387 ggml_vk_sync_buffers(ctx, subctx);
6388 vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6389
6390 for (uint64_t i3 = 0; i3 < ne3; i3++) {
6391 for (uint64_t i2 = 0; i2 < ne2; i2++) {
6392 // Find longest contiguous slice
6393 if (ne1*nb1 == dstnb2) {
6394 deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys);
6395 } else {
6396 for (uint64_t i1 = 0; i1 < ne1; i1++) {
6397 if (ne0*nb0/bs == dstnb1) {
6398 deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys);
6399 } else {
6400 const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
6401 const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
6402 for (uint64_t i0 = 0; i0 < ne0; i0++) {
6403 deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys);
6404 }
6405 }
6406 }
6407 }
6408 }
6409 }
6410}
6411
6412static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
6413 VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
6414 // Check if src is pinned memory
6415 vk_buffer buf = nullptr;
6416 size_t buf_offset = 0;
6417 ggml_vk_host_get(dst->device, src, buf, buf_offset);
6418
6419 if (buf != nullptr) {
6420 // Memory is pinned, use as staging buffer
6421 std::vector<vk::BufferCopy> slices(1);
6422 if (width == spitch) {
6423 // Only do single write if stride is equal
6424 slices[0].srcOffset = buf_offset;
6425 slices[0].dstOffset = offset;
6426 slices[0].size = width * height;
6427 } else {
6428 slices.resize(height);
6429 for (size_t i = 0; i < height; i++) {
6430 slices[i].srcOffset = buf_offset + i * spitch;
6431 slices[i].dstOffset = offset + i * width;
6432 slices[i].size = width;
6433 }
6434 }
6435
6436 ggml_vk_sync_buffers(nullptr, subctx);
6437 subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6438 return true;
6439 }
6440 VK_LOG_DEBUG("STAGING");
6441
6442 if (!sync_staging) {
6443 // copy was not handled; the caller needs to fall back
6444 return false;
6445 }
6446
6447 // Staging buffer required
6448 const size_t copy_size = width*height;
6449 ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
6450
6451 vk_buffer& staging_buffer = dst->device->sync_staging;
6452
6453 VkBufferCopy buf_copy = {
6454 0,
6455 offset,
6456 copy_size};
6457
6458 ggml_vk_sync_buffers(nullptr, subctx);
6459 vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6460
6461 if (width == spitch) {
6462 deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
6463 } else {
6464 for (size_t i = 0; i < height; i++) {
6465 deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
6466 }
6467 }
6468 return true;
6469}
6470
6471static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
6472 VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
6473 return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
6474}
6475
6476static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
6477 VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
6478 // Buffer is already mapped
6479 if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
6480 GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
6481
6482 for (size_t i = 0; i < height; i++) {
6483 memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
6484 }
6485 } else {
6486 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
6487
6488 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
6489 ggml_vk_ctx_begin(dst->device, subctx);
6490 bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
6491 GGML_ASSERT(ret);
6492 ggml_vk_ctx_end(subctx);
6493
6494 for (auto& cpy : subctx->in_memcpys) {
6495 memcpy(cpy.dst, cpy.src, cpy.n);
6496 }
6497
6498 for (auto& mset : subctx->memsets) {
6499 memset(mset.dst, mset.val, mset.n);
6500 }
6501
6502 ggml_vk_submit(subctx, dst->device->fence);
6503 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
6504 dst->device->device.resetFences({ dst->device->fence });
6505 ggml_vk_queue_command_pools_cleanup(dst->device);
6506 }
6507}
6508
6509static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
6510 VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
6511 ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
6512}
6513
6514static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
6515 VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")");
6516 GGML_ASSERT(width > 0);
6517 GGML_ASSERT(height > 0);
6518 GGML_ASSERT(src != nullptr);
6519
6520 // TODO: staging_offset is not used
6521
6522 // Check if dst is pinned memory
6523 vk_buffer buf = nullptr;
6524 size_t buf_offset = 0;
6525 ggml_vk_host_get(src->device, dst, buf, buf_offset);
6526
6527 std::vector<vk::BufferCopy> slices(1);
6528 if (width == spitch && width == dpitch) {
6529 // Only do single write if stride is equal
6530 slices[0].srcOffset = offset;
6531 slices[0].dstOffset = buf_offset;
6532 slices[0].size = width * height;
6533 } else {
6534 slices.resize(height);
6535 for (size_t i = 0; i < height; i++) {
6536 slices[i].srcOffset = offset + i * spitch;
6537 slices[i].dstOffset = buf_offset + i * dpitch;
6538 slices[i].size = width;
6539 }
6540 }
6541
6542 if (buf != nullptr) {
6543 // Memory is pinned, use as staging buffer
6544 ggml_vk_sync_buffers(nullptr, subctx);
6545 subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
6546
6547 return true;
6548 }
6549 VK_LOG_DEBUG("STAGING");
6550
6551 if (!sync_staging) {
6552 // copy was not handled; the caller needs to fall back
6553 return false;
6554 }
6555
6556 // Fall back to staging buffer
6557 const size_t copy_size = dpitch * height;
6558 ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
6559
6560 vk_buffer& staging_buffer = src->device->sync_staging;
6561
6562 ggml_vk_sync_buffers(nullptr, subctx);
6563 subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
6564
6565 deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
6566 return true;
6567}
6568
6569static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) {
6570 return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
6571}
6572
6573static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
6574 VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
6575
6576 // If the device is not a UMA device, the memory is host-accessible through ReBAR. While writing
6577 // through PCIe is sufficiently fast, reading data back over PCIe is slower than going through
6578 // the device-to-host copy path in hardware.
6579 if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
6580 GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
6581
6582 memcpy(dst, (uint8_t *) src->ptr + offset, size);
6583 } else {
6584 std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
6585
6586 vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
6587 ggml_vk_ctx_begin(src->device, subctx);
6588 bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
6589 GGML_ASSERT(ret);
6590 ggml_vk_ctx_end(subctx);
6591
6592 ggml_vk_submit(subctx, src->device->fence);
6593 VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
6594 src->device->device.resetFences({ src->device->fence });
6595 ggml_vk_queue_command_pools_cleanup(src->device);
6596
6597 for (auto& cpy : subctx->out_memcpys) {
6598 memcpy(cpy.dst, cpy.src, cpy.n);
6599 }
6600 }
6601}
6602
6603static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
6604 VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
6605 // Make sure both buffers are on the same device
6606 GGML_ASSERT(src->device == dst->device);
6607
6608 VkBufferCopy bc{ src_offset, dst_offset, size };
6609
6610 vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
6611}
6612
6613static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
6614 if (src->device == dst->device) {
6615 std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
6616 VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")");
6617 // Copy within the device
6618 vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
6619 ggml_vk_ctx_begin(src->device, subctx);
6620 ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size);
6621 ggml_vk_ctx_end(subctx);
6622 ggml_vk_submit(subctx, src->device->fence);
6623 VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences");
6624 src->device->device.resetFences({ src->device->fence });
6625 ggml_vk_queue_command_pools_cleanup(src->device);
6626 } else {
6627 VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")");
6628 // Copy device to device
6629 ggml_vk_ensure_sync_staging_buffer(src->device, size);
6630
6631 // Copy to src staging buffer
6632 ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
6633 // Copy to dst buffer
6634 ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
6635 }
6636}
6637
6638static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
6639 VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
6640
6641 if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
6642 dst->device->uma) {
6643 deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
6644 return;
6645 }
6646
6647 // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
6648 ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
6649}
6650
6651static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
6652 VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
6653
6654 if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
6655 dst->device->uma) {
6656 memset((uint8_t*)dst->ptr + offset, c, size);
6657 return;
6658 }
6659
6660 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
6661 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
6662 ggml_vk_ctx_begin(dst->device, subctx);
6663 subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
6664 ggml_vk_ctx_end(subctx);
6665
6666 ggml_vk_submit(subctx, dst->device->fence);
6667 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences");
6668 dst->device->device.resetFences({ dst->device->fence });
6669 ggml_vk_queue_command_pools_cleanup(dst->device);
6670}
6671
6672static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) {
6673 VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")");
6674
6675 if (disable_split_k) {
6676 return 1;
6677 }
6678
6679 uint32_t split_k = 1;
6680 if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) {
6681 // If k is 'large' and the SMs will fill less than halfway, use split_k.
6682 uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
6683 uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
6684
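// Worked example (hypothetical): with 48 shader cores, m_tiles * n_tiles == 8 and k == 4096,
// split_k starts at 48 / 8 == 6, stays under the 8x cap, and the 256-aligned split of
// CEIL_DIV(4096, 6) -> 768 still leaves work for the last split (768 * 5 < 4096), so
// split_k == 6 is used.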
6685 if (k >= 2048) {
6686 if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) {
6687 split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
6688 } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) {
6689 split_k = 3;
6690 }
6691 // Cap the split at 8x. Unless k is huge this is a lot of overhead.
6692 split_k = std::min(split_k, 8u);
6693
6694 // ggml_vk_matmul will align the splits to be a multiple of 256.
6695 // If this rounded up size would cause the last split to be empty,
6696 // then reduce the split count.
6697 while (true) {
6698 if (split_k == 1) {
6699 break;
6700 }
6701 uint32_t k_split = CEIL_DIV(k, split_k);
6702 k_split = ROUNDUP_POW2(k_split, 256);
6703 if (k_split * (split_k - 1) < k) {
6704 break;
6705 }
6706 split_k--;
6707 }
6708 }
6709 }
6710
6711 return split_k;
6712}
6713
6714static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
6715 VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
6716
6717 if (ctx->device->coopmat2) {
6718 const uint32_t shader_core_count = ctx->device->shader_core_count;
6719 const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
6720 const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]);
6721
6722 // Use large shader when the N dimension is greater than the medium shader's tile size
6723 uint32_t crossover_large = mmp->m->wg_denoms[1];
6724
6725 // Prefer large over medium if either:
6726 // - medium or large tiles would overfill the GPU
6727 // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not
6728 // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead)
6729 bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count ||
6730 // split_k==3 with large tiles likely better than medium tiles with no split_k.
6731 (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
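// e.g. with 60 shader cores, tiles_l == 18 and tiles_m == 35 (hypothetical values):
// 18 <= 60 / 3 and 35 > 60 / 2, so the large tile size is preferred.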
6732
6733 if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
6734 return aligned ? mmp->a_l : mmp->l;
6735 }
6736 // Use medium shader when the N dimension is greater than the small shader's tile size
6737 uint32_t crossover_medium = mmp->s->wg_denoms[1];
6738 if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
6739 return aligned ? mmp->a_m : mmp->m;
6740 }
6741 return aligned ? mmp->a_s : mmp->s;
6742 }
6743
6744 if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
6745 return aligned ? mmp->a_s : mmp->s;
6746 }
6747 if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
6748 return aligned ? mmp->a_m : mmp->m;
6749 }
6750 return aligned ? mmp->a_l : mmp->l;
6751
6752 GGML_UNUSED(src1_type);
6753}
6754
6755static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
6756 VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
6757 return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
6758}
6759
6760static void ggml_vk_matmul(
6761 ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
6762 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
6763 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
6764 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
6765 uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
6766 uint32_t padded_n) {
6767 VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
6768 if (split_k == 1) {
6769 const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
6770 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
6771 return;
6772 }
6773
6774 if (ctx->prealloc_split_k_need_sync) {
6775 ggml_vk_sync_buffers(ctx, subctx);
6776 }
6777
6778 GGML_ASSERT(batch_stride_d == m * n);
6779
6780 // Round the split size up to a multiple of 256 (k-quant alignment)
6781 uint32_t k_split = CEIL_DIV(k, split_k);
6782 k_split = ROUNDUP_POW2(k_split, 256);
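// e.g. k == 1000 with split_k == 4 (hypothetical values) gives CEIL_DIV(1000, 4) == 250,
// rounded up to 256; the first three splits each process 256 values of k and the last one
// processes the remaining 232.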
6783
6784 const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
6785 // Make sure enough workgroups get assigned for split k to work
6786 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
6787 ggml_vk_sync_buffers(ctx, subctx);
6788 const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
6789 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
6790 ctx->prealloc_split_k_need_sync = true;
6791}
6792
6793static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
6794 VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
6795
6796 if (ctx->device->coopmat2) {
6797 // Use large shader when the N dimension is greater than the medium shader's tile size
6798 uint32_t crossover_large = mmp->m->wg_denoms[1];
6799 if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
6800 return aligned ? mmp->a_l : mmp->l;
6801 }
6802 // Use medium shader when the N dimension is greater than the small shader's tile size
6803 uint32_t crossover_medium = mmp->s->wg_denoms[1];
6804 if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
6805 return aligned ? mmp->a_m : mmp->m;
6806 }
6807 return aligned ? mmp->a_s : mmp->s;
6808 }
6809
6810 if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
6811 return aligned ? mmp->a_s : mmp->s;
6812 }
6813 if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
6814 return aligned ? mmp->a_m : mmp->m;
6815 }
6816 return aligned ? mmp->a_l : mmp->l;
6817}
6818
6819static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
6820 VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
6821 return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
6822}
6823
6824static void ggml_vk_matmul_id(
6825 ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
6826 vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, const vk_subbuffer & expert_count_buf,
6827 uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
6828 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
6829 uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
6830 uint32_t padded_n) {
6831 VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), expert_count: (" << expert_count_buf.buffer->buffer << ", " << expert_count_buf.offset << ", " << expert_count_buf.size << "), " <<
6832 "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
6833 "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
6834 "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
6835 const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
6836 nei0, nei1, nbi1, ne11, padded_n };
6837 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids, expert_count_buf }, pc, { m, nei1, n_as });
6838}
6839
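// Returns true when dims 0 and 1 are contiguous: nb[0] is one element and nb[1] spans a full row
// (in blocks for quantized types). Dim 2 may have an arbitrary stride, and dim 3 only has to be
// contiguous relative to dim 2 when ne[3] > 1.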
6840static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
6841 return
6842 tensor->nb[0] == ggml_type_size(tensor->type) &&
6843 tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) &&
6844 (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]);
6845}
6846
6847static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
6848
6849 // Choose "contiguous copy" shader if src/dst are contiguous
6850 bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
6851
6852 // Use optimized "transpose" shader if src dim1 is the innermost dimension.
6853 bool transpose = dst && src->nb[1] == ggml_type_size(to) && ggml_are_same_shape(dst, src);
6854
6855 if (transpose && src->type == to) {
6856 if (ggml_type_size(to) == 4) {
6857 return ctx->device->pipeline_cpy_transpose_32;
6858 } else if (ggml_type_size(to) == 2) {
6859 return ctx->device->pipeline_cpy_transpose_16;
6860 }
6861 }
6862
6863 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
6864 if (contig) {
6865 return ctx->device->pipeline_contig_cpy_f32_f32;
6866 } else {
6867 return ctx->device->pipeline_cpy_f32_f32;
6868 }
6869 }
6870 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
6871 if (contig) {
6872 return ctx->device->pipeline_contig_cpy_f32_f16;
6873 } else {
6874 return ctx->device->pipeline_cpy_f32_f16;
6875 }
6876 }
6877 if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
6878 if (contig) {
6879 return ctx->device->pipeline_contig_cpy_f16_f16;
6880 } else {
6881 return ctx->device->pipeline_cpy_f16_f16;
6882 }
6883 }
6884 if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
6885 if (contig) {
6886 return ctx->device->pipeline_contig_cpy_f16_f32;
6887 } else {
6888 return ctx->device->pipeline_cpy_f16_f32;
6889 }
6890 }
6891 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
6892 if (contig) {
6893 return ctx->device->pipeline_contig_cpy_f32_bf16;
6894 } else {
6895 return ctx->device->pipeline_cpy_f32_bf16;
6896 }
6897 }
6898 if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
6899 if (contig) {
6900 return ctx->device->pipeline_contig_cpy_f32_i32;
6901 } else {
6902 return ctx->device->pipeline_cpy_f32_i32;
6903 }
6904 }
6905 if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
6906 if (contig) {
6907 return ctx->device->pipeline_contig_cpy_i32_f32;
6908 } else {
6909 return ctx->device->pipeline_cpy_i32_f32;
6910 }
6911 }
6912 if (src->type == GGML_TYPE_F32) {
6913 switch (to) {
6914 case GGML_TYPE_Q4_0:
6915 case GGML_TYPE_Q4_1:
6916 case GGML_TYPE_Q5_0:
6917 case GGML_TYPE_Q5_1:
6918 case GGML_TYPE_Q8_0:
6919 case GGML_TYPE_IQ4_NL:
6920 return ctx->device->pipeline_cpy_f32_quant[to];
6921 default:
6922 break;
6923 }
6924 }
6925
6926 if (to == GGML_TYPE_F32) {
6927 switch (src->type) {
6928 case GGML_TYPE_Q4_0:
6929 case GGML_TYPE_Q4_1:
6930 case GGML_TYPE_Q5_0:
6931 case GGML_TYPE_Q5_1:
6932 case GGML_TYPE_Q8_0:
6933 case GGML_TYPE_IQ4_NL:
6934 return ctx->device->pipeline_cpy_quant_f32[src->type];
6935 default:
6936 break;
6937 }
6938 }
6939
6940 if (src->type == to) {
6941 // Copy two or four bytes at a time, depending on the type size.
6942 // For quantized types, the element count is scaled by block size / type size.
6943 // This path is also used for e.g. bf16->bf16, where the type size must be
6944 // exactly 2 or 4.
6945 GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
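        // For example, F32/I32 (4 bytes) and Q4_1 blocks (20 bytes) take the 4-byte path below,
        // while BF16 (2 bytes) and Q4_0 blocks (18 bytes) take the 2-byte path.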
6946 if ((ggml_type_size(src->type) % 4) == 0) {
6947 if (contig) {
6948 return ctx->device->pipeline_contig_cpy_f32_f32;
6949 } else {
6950 return ctx->device->pipeline_cpy_f32_f32;
6951 }
6952 } else {
6953 if (contig) {
6954 return ctx->device->pipeline_contig_cpy_f16_f16;
6955 } else {
6956 return ctx->device->pipeline_cpy_f16_f16;
6957 }
6958 }
6959 }
6960
6961 std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
6962 GGML_ABORT("fatal error");
6963}
6964
6965static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, const vk_subbuffer & in, const vk_subbuffer & out) {
6966 VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
6967 std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
6968 const int tensor_type_size = ggml_type_size(tensor->type);
6969
6970 const uint32_t ne = ggml_nelements(tensor);
6971 std::array<uint32_t, 3> elements;
6972
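    // Spread ne over up to three dispatch dimensions because each workgroup-count dimension has a
    // device limit (commonly 65535): up to 512 * 512 = 262144 elements fit in the first two
    // dimensions, and anything larger spills into the z dimension.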
6973 if (ne > 262144) {
6974 elements = { 512, 512, CEIL_DIV(ne, 262144) };
6975 } else if (ne > 512) {
6976 elements = { 512, CEIL_DIV(ne, 512), 1 };
6977 } else {
6978 elements = { ne, 1, 1 };
6979 }
6980
6981 vk_op_unary_push_constants pc = {
6982 (uint32_t)ne,
6983 (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
6984 (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1, (uint32_t)tensor->ne[0], (uint32_t)(tensor->ne[0] * tensor->ne[1]), (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
6985 0,
6986 0.0f, 0.0f,
6987 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6988 };
6989 init_pushconst_fastdiv(pc);
6990 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
6991 ggml_vk_sync_buffers(ctx, subctx);
6992}
6993
6994static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
6995 switch(type) {
6996 case GGML_TYPE_Q8_1:
6997 return ctx->device->pipeline_quantize_q8_1_x4;
6998 default:
6999 std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
7000 GGML_ABORT("fatal error");
7001 }
7002}
7003
7004static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, const vk_subbuffer & in, const vk_subbuffer & out, uint32_t ne) {
7005 VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
7006
7007 vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
7008
7009 const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
7010 // Clamp the number of dispatched elements to the maximum workgroup count; the shader iterates over the total number of blocks regardless.
7011 const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
7012 const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
7013
7014 const vk_quantize_q8_1_push_constants pc = {
7015 ne,
7016 num_blocks,
7017 };
7018
7019 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 });
7020 ggml_vk_sync_buffers(ctx, subctx);
7021}
7022
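// 64-bit-indexing variants of a pipeline are assumed to be chained off the base pipeline via its
// `next` pointer; walk that chain and return the variant flagged is_64b_indexing if one exists,
// otherwise return the pipeline unchanged. Callers use this when a tensor exceeds
// maxStorageBufferRange.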
7023static vk_pipeline ggml_vk_get_64b_indexing_pipeline(ggml_backend_vk_context * ctx, vk_pipeline &pipeline) {
7024 GGML_UNUSED(ctx);
7025#if defined(VK_EXT_shader_64bit_indexing)
7026 vk_pipeline *ptr = &pipeline;
7027 while (*ptr) {
7028 if ((*ptr)->is_64b_indexing) {
7029 return *ptr;
7030 }
7031 ptr = &(*ptr)->next;
7032 }
7033#endif
7034 return pipeline;
7035}
7036
7037static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k) {
7038 VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7039 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7040 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7041 std::cerr << "))");
7042 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
7043 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
7044
7045 const uint64_t ne00 = src0->ne[0];
7046 const uint64_t ne01 = src0->ne[1];
7047 const uint64_t ne02 = src0->ne[2];
7048 const uint64_t ne03 = src0->ne[3];
7049
7050 const uint64_t ne10 = src1->ne[0];
7051 const uint64_t ne11 = src1->ne[1];
7052 const uint64_t ne12 = src1->ne[2];
7053 const uint64_t ne13 = src1->ne[3];
7054
7055 const uint64_t ne21 = dst->ne[1];
7056 const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type);
7057 const uint32_t stride_batch_d = stride_d*ne21;
7058
7059 const uint64_t r2 = ne12 / ne02;
7060 const uint64_t r3 = ne13 / ne03;
7061
7062 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
7063 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
7064 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
7065
7066 vk_buffer d_Qx = nullptr;
7067 size_t qx_buf_offset = 0;
7068 vk_buffer d_Qy = nullptr;
7069 size_t qy_buf_offset = 0;
7070
7071 bool src0_uma = false;
7072 bool src1_uma = false;
7073
7074 if (ctx->device->uma) {
7075 ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
7076 ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
7077 src0_uma = d_Qx != nullptr;
7078 src1_uma = d_Qy != nullptr;
7079 }
7080
7081 // Reformat and convert to fp16 if non-contiguous, or when using coopmat2 (for better performance)
7082 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
7083 !ggml_vk_dim01_contiguous(src0);
7084 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
7085 (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
7086 !ggml_vk_dim01_contiguous(src1);
7087
7088 // If src0 is BF16, try to use a BF16 x BF16 multiply
7089 ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
7090
7091 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
7092
7093 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
7094
7095 // Check for mmq first
7096 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
7097
7098 if (mmp == nullptr) {
7099 // Fall back to f16 dequant mul mat
7100 mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
7101 quantize_y = false;
7102 }
7103
7104 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
7105 const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
7106
7107 if (qx_needs_dequant) {
7108 // Fall back to dequant + f16 mulmat
7109 mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
7110 }
7111
7112 // Not implemented
7113 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7114
7115 const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
7116 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
7117
7118 vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
7119
7120 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7121 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
7122 }
7123
7124 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
7125 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
7126 const uint64_t x_ne = ggml_nelements(src0);
7127 // 128 elements per Q8_1 x4 block
7128 const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
7129 const uint64_t d_ne = ggml_nelements(dst);
7130
7131 const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline);
7132
7133 const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
7134 const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
7135 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
7136 const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
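    // With quantize_y, Y is stored as q8_1_x4: four Q8_1 blocks (128 elements) pack into 144 bytes
    // (4 * (32 quants + 2 fp16 scales)), which is why y_ne is aligned to 128 here and prealloc_y is
    // later checked against a multiple of 144.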
7137 const uint64_t d_sz = sizeof(float) * d_ne;
7138
7139 vk_pipeline to_fp16_vk_0 = nullptr;
7140 vk_pipeline to_fp16_vk_1 = nullptr;
7141 vk_pipeline to_q8_1 = nullptr;
7142
7143 if (x_non_contig) {
7144 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
7145 } else {
7146 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
7147 }
7148 if (y_non_contig) {
7149 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
7150 } else {
7151 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7152 }
7153 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7154 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7155
7156 if (quantize_y) {
7157 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
7158 }
7159
7160 {
7161 const uint64_t split_k_size = split_k > 1 ? d_sz * split_k : 0;
7162 if (
7163 (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
7164 (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
7165 (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) {
7166 GGML_ABORT("Requested preallocation size is too large");
7167 }
7168 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
7169 ctx->prealloc_size_x = x_sz;
7170 ggml_vk_preallocate_buffers(ctx, subctx);
7171 }
7172 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
7173 ctx->prealloc_size_y = y_sz;
7174 ggml_vk_preallocate_buffers(ctx, subctx);
7175 }
7176 if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
7177 ctx->prealloc_size_split_k = split_k_size;
7178 ggml_vk_preallocate_buffers(ctx, subctx);
7179 }
7180
7181 // Request descriptor sets
7182 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7183 if (qx_needs_dequant) {
7184 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
7185 }
7186 if (qy_needs_dequant) {
7187 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
7188 }
7189 if (quantize_y) {
7190 ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7191 }
7192 if (split_k > 1) {
7193 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1);
7194 }
7195 }
7196
7197 vk_buffer d_D = dst_buf_ctx->dev_buffer;
7198 const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
7199 GGML_ASSERT(d_D != nullptr);
7200 GGML_ASSERT(d_D->size >= d_buf_offset + d_sz);
7201 vk_buffer d_X;
7202 uint64_t x_buf_offset = 0;
7203 vk_buffer d_Y;
7204 uint64_t y_buf_offset = 0;
7205 if (!src0_uma) {
7206 d_Qx = src0_buf_ctx->dev_buffer;
7207 qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
7208 GGML_ASSERT(d_Qx != nullptr);
7209 }
7210 if (!src1_uma) {
7211 d_Qy = src1_buf_ctx->dev_buffer;
7212 qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
7213 GGML_ASSERT(d_Qy != nullptr);
7214 }
7215 if (qx_needs_dequant) {
7216 d_X = ctx->prealloc_x;
7217 GGML_ASSERT(d_X->size >= x_sz);
7218 } else {
7219 d_X = d_Qx;
7220 x_buf_offset = qx_buf_offset;
7221 GGML_ASSERT(qx_sz == x_sz);
7222 }
7223 if (qy_needs_dequant) {
7224 d_Y = ctx->prealloc_y;
7225 GGML_ASSERT(d_Y->size >= y_sz);
7226 } else if (quantize_y) {
7227 d_Y = ctx->prealloc_y;
7228 GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
7229 } else {
7230 d_Y = d_Qy;
7231 y_buf_offset = qy_buf_offset;
7232 GGML_ASSERT(qy_sz == y_sz);
7233 }
7234
7235 if (x_non_contig || qx_needs_dequant) {
7236 if (ctx->prealloc_x_need_sync) {
7237 ggml_vk_sync_buffers(ctx, subctx);
7238 }
7239 }
7240
7241 if (x_non_contig) {
7242 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
7243 } else if (qx_needs_dequant) {
7244 const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
7245 ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)(x_ne), 1, 1});
7246 ggml_vk_sync_buffers(ctx, subctx);
7247 }
7248 if (y_non_contig) {
7249 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7250 ctx->prealloc_y_last_tensor_used != src1) {
7251 if (ctx->prealloc_y_need_sync) {
7252 ggml_vk_sync_buffers(ctx, subctx);
7253 }
7254 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
7255 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7256 ctx->prealloc_y_last_tensor_used = src1;
7257 }
7258 }
7259 if (quantize_y) {
7260 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7261 ctx->prealloc_y_last_tensor_used != src1) {
7262 if (ctx->prealloc_y_need_sync) {
7263 ggml_vk_sync_buffers(ctx, subctx);
7264 }
7265 ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
7266 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7267 ctx->prealloc_y_last_tensor_used = src1;
7268 }
7269 }
7270
7271 uint32_t stride_batch_x = ne00*ne01;
7272 uint32_t stride_batch_y = ne10*ne11;
7273
7274 if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
7275 stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
7276 }
7277
7278 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
7279 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
7280 }
7281
7282 // compute
7283 ggml_vk_matmul(
7284 ctx, subctx, pipeline,
7285 { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
7286 ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * split_k },
7287 ne01, ne11, ne10,
7288 ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d,
7289 split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
7290 ); // NOLINT
7291
7292 if (x_non_contig || qx_needs_dequant) {
7293 ctx->prealloc_x_need_sync = true;
7294 }
7295 if (y_non_contig || quantize_y) {
7296 ctx->prealloc_y_need_sync = true;
7297 }
7298}
7299
7300// Device tuning
7301static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) {
7302 if (device->mmvq_mode == 1) {
7303 return true;
7304 } else if (device->mmvq_mode == -1) {
7305 return false;
7306 }
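    // mmvq_mode == 1 forces MMVQ on and -1 forces it off; any other value falls through to the
    // per-vendor heuristics below.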
7307
7308 // General performance issue with q3_k and q6_k due to 2-byte alignment
7309 if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
7310 return false;
7311 }
7312
7313 // MMVQ is generally good for batches
7314 if (n > 1) {
7315 return true;
7316 }
7317
7318 // Quantization overhead is not worth it for small k
7319 switch (device->vendor_id) {
7320 case VK_VENDOR_ID_NVIDIA:
7321 if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
7322 return true;
7323 }
7324
7325 if (k <= 4096) {
7326 return false;
7327 }
7328
7329 switch (src0_type) {
7330 case GGML_TYPE_MXFP4:
7331 case GGML_TYPE_Q8_0:
7332 return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
7333 default:
7334 return true;
7335 }
7336 case VK_VENDOR_ID_AMD:
7337 if (k < 2048) {
7338 return false;
7339 }
7340
7341 switch (src0_type) {
7342 case GGML_TYPE_Q8_0:
7343 return device->architecture == vk_device_architecture::AMD_GCN;
7344 default:
7345 return true;
7346 }
7347 case VK_VENDOR_ID_INTEL:
7348 if (k < 2048) {
7349 return false;
7350 }
7351
7352 switch (src0_type) {
7353 // From tests on A770 Linux, may need more tuning
7354 case GGML_TYPE_Q4_0:
7355 case GGML_TYPE_Q5_1:
7356 return false;
7357 default:
7358 return true;
7359 }
7360 default:
7361 return true;
7362 }
7363
7364 GGML_UNUSED(m);
7365}
7366
7367static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7368 ggml_tensor * dst = cgraph->nodes[node_idx];
7369 const ggml_tensor * src0 = dst->src[0];
7370 const ggml_tensor * src1 = dst->src[1];
7371
7372 VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7373 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7374 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7375 std::cerr << ")),)");
7376 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
7377 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
7378
7379 const uint64_t ne00 = src0->ne[0];
7380 const uint64_t ne01 = src0->ne[1];
7381 const uint64_t ne02 = src0->ne[2];
7382 const uint64_t ne03 = src0->ne[3];
7383
7384 const uint64_t ne10 = src1->ne[0];
7385 const uint64_t ne11 = src1->ne[1];
7386 const uint64_t ne12 = src1->ne[2];
7387 const uint64_t ne13 = src1->ne[3];
7388
7389 const uint64_t ne20 = dst->ne[0];
7390 const uint64_t ne21 = dst->ne[1];
7391 // const uint64_t ne22 = dst->ne[2];
7392 // const uint64_t ne23 = dst->ne[3];
7393
7394 const uint64_t r2 = ne12 / ne02;
7395 const uint64_t r3 = ne13 / ne03;
7396
7397 // batch_n indicates that we need to compute a few vector results, and this assumes
7398 // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
7399 GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
7400 bool batch_n = ne11 > 1;
7401
7402 const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
7403 const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
7404
7405 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
7406 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type);
7407
7408 vk_pipeline to_fp16_vk_0 = nullptr;
7409 vk_pipeline to_fp16_vk_1 = nullptr;
7410 if (x_non_contig) {
7411 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
7412 }
7413 if (y_non_contig) {
7414 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
7415 } else {
7416 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7417 }
7418
7419 // Check for mmq first
7420 vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr;
7421 vk_pipeline to_q8_1 = nullptr;
7422
7423 if (dmmv == nullptr) {
7424 // Fall back to f16 dequant mul mat
7425 dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00);
7426 quantize_y = false;
7427 }
7428
7429 if (quantize_y) {
7430 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
7431 }
7432
7433 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7434 dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
7435 }
7436
7437 const bool qx_needs_dequant = x_non_contig;
7438 const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
7439
7440 // Not implemented
7441 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7442
7443 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7444 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7445 GGML_ASSERT(dmmv != nullptr);
7446
7447 const uint64_t x_ne = ggml_nelements(src0);
7448 const uint64_t y_ne = ggml_nelements(src1);
7449
7450 const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
7451 const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
7452 const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
7453 (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
7454
7455 {
7456 if (
7457 (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
7458 (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
7459 GGML_ABORT("Requested preallocation size is too large");
7460 }
7461 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
7462 ctx->prealloc_size_x = x_sz;
7463 ggml_vk_preallocate_buffers(ctx, subctx);
7464 }
7465 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
7466 ctx->prealloc_size_y = y_sz;
7467 ggml_vk_preallocate_buffers(ctx, subctx);
7468 }
7469
7470 // Request descriptor sets
7471 if (qx_needs_dequant) {
7472 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
7473 }
7474 if (qy_needs_dequant) {
7475 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
7476 }
7477 if (quantize_y) {
7478 ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7479 }
7480 ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
7481 }
7482
7483 vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
7484 vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
7485 vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
7486 vk_subbuffer d_X, d_Y;
7487
7488 if (qx_needs_dequant) {
7489 d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
7490 } else {
7491 d_X = d_Qx;
7492 GGML_ASSERT(qx_sz == x_sz);
7493 }
7494 if (qy_needs_dequant || quantize_y) {
7495 d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
7496 } else {
7497 d_Y = d_Qy;
7498 }
7499
7500 if (x_non_contig) {
7501 if (ctx->prealloc_x_need_sync) {
7502 ggml_vk_sync_buffers(ctx, subctx);
7503 }
7504
7505 GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
7506 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
7507 }
7508 if (y_non_contig) {
7509 GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
7510 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7511 ctx->prealloc_y_last_tensor_used != src1) {
7512 if (ctx->prealloc_y_need_sync) {
7513 ggml_vk_sync_buffers(ctx, subctx);
7514 }
7515 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
7516 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7517 ctx->prealloc_y_last_tensor_used = src1;
7518 }
7519 }
7520 if (quantize_y) {
7521 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7522 ctx->prealloc_y_last_tensor_used != src1) {
7523 if (ctx->prealloc_y_need_sync) {
7524 ggml_vk_sync_buffers(ctx, subctx);
7525 }
7526 ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
7527 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7528 ctx->prealloc_y_last_tensor_used = src1;
7529 }
7530 }
7531
7532 // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
7533 uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
7534 uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
7535 uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
7536
7537 if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
7538 stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
7539 }
7540
7541 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
7542 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
7543 }
7544
7545 const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
7546
7547 uint32_t groups_x = ne01;
7548 uint32_t groups_z = 1;
7549
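    // Each row of A normally maps to one workgroup in x. If ne01 exceeds the device's
    // workgroup-count limit in x, fold the rows across the z dimension: groups_z becomes 64 and
    // groups_x shrinks to CEIL_DIV(ne01, 64), so x * z still covers every row.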
7550 if (ne01 > max_groups_x) {
7551 groups_z = 64;
7552 groups_x = CEIL_DIV(groups_x, groups_z);
7553 }
7554
7555 uint32_t fusion_flags = 0;
7556
7557 vk_subbuffer d_F0 = d_D;
7558 if (ctx->num_additional_fused_ops > 0) {
7559 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7560 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
7561
7562 d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
7563 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
7564 }
7565
7566 vk_subbuffer d_F1 = d_D;
7567 if (ctx->num_additional_fused_ops == 2) {
7568 const ggml_tensor * add = cgraph->nodes[node_idx + 2];
7569 const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
7570
7571 d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
7572 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
7573 }
7574
7575 // compute
7576 const vk_mat_vec_push_constants pc = {
7577 (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
7578 stride_batch_x, stride_batch_y, stride_batch_d,
7579 fusion_flags,
7580 (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
7581 };
7582 ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
7583 {
7584 d_X,
7585 d_Y,
7586 d_D,
7587 d_F0,
7588 d_F1,
7589 },
7590 pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
7591
7592 if (x_non_contig) {
7593 ctx->prealloc_x_need_sync = true;
7594 }
7595 if (y_non_contig || quantize_y) {
7596 ctx->prealloc_y_need_sync = true;
7597 }
7598}
7599
7600static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7601 ggml_tensor * dst = cgraph->nodes[node_idx];
7602 const ggml_tensor * src0 = dst->src[0];
7603 const ggml_tensor * src1 = dst->src[1];
7604 VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7605 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7606 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7607 std::cerr << "))");
7608 GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
7609 GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT
7610 GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT
7611 GGML_ASSERT(src0->type == GGML_TYPE_F16);
7612 GGML_ASSERT(src1->type == GGML_TYPE_F32);
7613
7614 const uint64_t ne00 = src0->ne[0];
7615 const uint64_t ne01 = src0->ne[1];
7616 const uint64_t ne02 = src0->ne[2];
7617 // const uint64_t ne03 = src0->ne[3];
7618
7619 //const uint64_t ne10 = src1->ne[0];
7620 const uint64_t ne11 = src1->ne[1];
7621 const uint64_t ne12 = src1->ne[2];
7622 // const uint64_t ne13 = src1->ne[3];
7623
7624 GGML_ASSERT(ne11 == 1);
7625
7626 // With grouped query attention there are > 1 Q matrices per K, V matrix.
7627 uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
7628 if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
7629 gqa_ratio = 1;
7630 }
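    // gqa_ratio selects a shader variant that handles gqa_ratio query heads per K/V head in a
    // single workgroup; it only applies when ne12 is an exact multiple of ne02 with a ratio of at
    // most 8, otherwise the ratio falls back to 1.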
7631
7632 vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1];
7633
7634 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7635 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
7636 }
7637
7638 {
7639 // Request descriptor sets
7640 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7641 }
7642
7643 vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
7644 vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
7645 vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
7646
7647 vk_subbuffer d_F0 = d_D;
7648
7649 uint32_t fusion_flags = 0;
7650
7651 if (ctx->num_additional_fused_ops > 0) {
7652 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7653 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
7654
7655 d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
7656 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
7657 }
7658
7659 vk_subbuffer d_F1 = d_D;
7660 if (ctx->num_additional_fused_ops > 1) {
7661 const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
7662
7663 d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
7664 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
7665 }
7666
7667 // compute
7668
7669 vk_mat_vec_p021_push_constants pc = {
7670 (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
7671 0, 0, fusion_flags
7672 };
7673
7674 init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
7675
7676 uint32_t workgroups_z = (uint32_t)ne12;
7677 // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
7678 if (gqa_ratio > 1) {
7679 workgroups_z /= gqa_ratio;
7680 }
7681
7682 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
7683 {
7684 d_Qx,
7685 d_Qy,
7686 d_D,
7687 d_F0,
7688 d_F1,
7689 }, pc, { 1, (uint32_t)ne01, workgroups_z });
7690}
7691
7692static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7693 ggml_tensor * dst = cgraph->nodes[node_idx];
7694 const ggml_tensor * src0 = dst->src[0];
7695 const ggml_tensor * src1 = dst->src[1];
7696 VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7697 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7698 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
7699 std::cerr << "))");
7700 GGML_ASSERT(!ggml_is_transposed(src0));
7701 GGML_ASSERT(!ggml_is_transposed(src1));
7702 GGML_ASSERT(!ggml_is_permuted(src0));
7703 GGML_ASSERT(src0->type == GGML_TYPE_F16);
7704 GGML_ASSERT(src1->type == GGML_TYPE_F32);
7705
7706 const uint64_t ne00 = src0->ne[0];
7707 const uint64_t ne01 = src0->ne[1];
7708 const uint64_t ne02 = src0->ne[2];
7709 const uint64_t ne03 = src0->ne[3];
7710
7711 const uint64_t nb01 = src0->nb[1];
7712 const uint64_t nb02 = src0->nb[2];
7713
7714 const uint64_t nb12 = src1->nb[2];
7715
7716 // const uint64_t ne10 = src1->ne[0];
7717 const uint64_t ne11 = src1->ne[1];
7718 const uint64_t ne12 = src1->ne[2];
7719 // const uint64_t ne13 = src1->ne[3];
7720
7721 const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
7722 const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
7723 const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
7724
7725 GGML_ASSERT(ne11 == 1);
7726 GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
7727
7728 const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
7729 const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
7730 const uint32_t channel_stride_y = nb12 / sizeof(float);
7731
7732 vk_pipeline pipeline = ctx->device->pipeline_mul_mat_vec_nc_f16_f32;
7733 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7734 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
7735 }
7736
7737 {
7738 // Request descriptor sets
7739 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7740 }
7741
7742 vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
7743 vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
7744 vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
7745 vk_subbuffer d_F0 = d_D;
7746
7747 uint32_t fusion_flags = 0;
7748
7749 if (ctx->num_additional_fused_ops > 0) {
7750 const ggml_tensor * add = cgraph->nodes[node_idx + 1];
7751 const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
7752
7753 d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
7754 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
7755 }
7756
7757 vk_subbuffer d_F1 = d_D;
7758 if (ctx->num_additional_fused_ops > 1) {
7759 const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
7760
7761 d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
7762 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
7763 }
7764
7765 // compute
7766 vk_mat_vec_nc_push_constants pc = {
7767 (uint32_t)ne00, (uint32_t)ne01,
7768 row_stride_x, channel_stride_x, channel_stride_y,
7769 (uint32_t)(ne12 / ne02), (uint32_t)ne12,
7770 0, 0,
7771 nb03, nb13, nb23, fusion_flags
7772 };
7773
7774 init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
7775
7776 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
7777 {
7778 d_Qx,
7779 d_Qy,
7780 d_D,
7781 d_F0,
7782 d_F1,
7783 }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
7784}
7785
7786static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
7787 ggml_tensor * dst = cgraph->nodes[node_idx];
7788 ggml_tensor * src0 = dst->src[0];
7789 ggml_tensor * src1 = dst->src[1];
7790 VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
7791
7792 // Handle a huge A matrix by splitting the M dimension. This works well for convolution use cases
7793 // where the M dimension is very large.
7794 // split_k doesn't work with M splitting.
7795 // This path only supports a batch size of 1.
7796 const size_t nbytes = ggml_nbytes(src0);
7797 const bool needs_split = dst->ne[2] == 1 && dst->ne[3] == 1 && nbytes > ctx->device->properties.limits.maxStorageBufferRange;
7798 if (needs_split) {
7799 // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets)
7800 const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]);
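        // Example (a sketch with made-up numbers): maxStorageBufferRange = 128 MiB and
        // src0->nb[1] = 16 KiB per row gives M_split = 134217728 / (2 * 16384) = 4096 rows per chunk.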
7801 uint32_t m_offset = 0;
7802 while (m_offset < dst->ne[0]) {
7803 const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset));
7804 ggml_tensor dst2 = *dst;
7805 ggml_tensor src02 = *src0;
7806
7807 dst2.view_src = dst->view_src ? dst->view_src : dst;
7808 src02.view_src = src0->view_src ? src0->view_src : src0;
7809
7810 dst2.view_offs += m_offset * dst->nb[0];
7811 src02.view_offs += m_offset * src0->nb[1];
7812 dst2.ne[0] = cur_M_size;
7813 src02.ne[1] = cur_M_size;
7814
7815 ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true);
7816
7817 m_offset += cur_M_size;
7818 }
7819 } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
7820 // detect 0213 permutation, and batch size of 1
7821 src0->nb[0] <= src0->nb[2] &&
7822 src0->nb[2] <= src0->nb[1] &&
7823 src0->nb[1] <= src0->nb[3] &&
7824 src1->nb[0] <= src1->nb[2] &&
7825 src1->nb[2] <= src1->nb[1] &&
7826 src1->nb[1] <= src1->nb[3] &&
7827 src0->ne[3] == 1 &&
7828 src1->ne[3] == 1) {
7829 ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
7830 } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
7831 !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
7832 ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
7833 // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to mul_mat_vec_max_cols)
7834 // when ne12 and ne13 are one.
7835 } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
7836 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
7837 ggml_vk_mul_mat_vec_q_f16(ctx, subctx, cgraph, node_idx);
7838 } else {
7839 ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false);
7840 }
7841}
7842
7843static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
7844 VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
7845 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
7846 std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
7847 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)");
7848 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
7849 GGML_ASSERT(ids->type == GGML_TYPE_I32);
7850
7851 const uint64_t ne00 = src0->ne[0];
7852 const uint64_t ne01 = src0->ne[1];
7853 const uint64_t ne02 = src0->ne[2];
7854 // const uint64_t ne03 = src0->ne[3];
7855
7856 const uint64_t ne10 = src1->ne[0];
7857 const uint64_t ne11 = src1->ne[1];
7858 const uint64_t ne12 = src1->ne[2];
7859 const uint64_t ne13 = src1->ne[3];
7860
7861 const uint64_t nei0 = ids->ne[0];
7862 const uint64_t nei1 = ids->ne[1];
7863
7864 const uint32_t nbi0 = ids->nb[0];
7865 const uint32_t nbi1 = ids->nb[1];
7866 const uint32_t nbi2 = ids->nb[2];
7867
7868 const uint64_t ne20 = dst->ne[0];
7869 const uint64_t ne21 = dst->ne[1];
7870 // const uint64_t ne22 = dst->ne[2];
7871 // const uint64_t ne23 = dst->ne[3];
7872
7873 const uint64_t n_as = ne02;
7874
7875 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
7876 ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
7877 ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
7878 ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
7879
7880 vk_buffer d_Qx = nullptr;
7881 size_t qx_buf_offset = 0;
7882 vk_buffer d_Qy = nullptr;
7883 size_t qy_buf_offset = 0;
7884 vk_buffer d_ids = nullptr;
7885 size_t ids_buf_offset = 0;
7886
7887 bool src0_uma = false;
7888 bool src1_uma = false;
7889 bool ids_uma = false;
7890
7891 if (ctx->device->uma) {
7892 ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset);
7893 ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
7894 ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
7895 src0_uma = d_Qx != nullptr;
7896 src1_uma = d_Qy != nullptr;
7897 ids_uma = d_ids != nullptr;
7898 }
7899
7900 // Reformat and convert to fp16 if non-contiguous, or when using coopmat2 (for better performance)
7901 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
7902 !ggml_vk_dim01_contiguous(src0);
7903 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
7904 (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
7905 !ggml_vk_dim01_contiguous(src1);
7906
7907 // If src0 is BF16, try to use a BF16 x BF16 multiply
7908 ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
7909
7910 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
7911
7912 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0;
7913
7914 // Check for mmq first
7915 vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
7916
7917 if (mmp == nullptr) {
7918 // Fall back to f16 dequant mul mat
7919 mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
7920 quantize_y = false;
7921 }
7922
7923 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
7924 const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
7925
7926 if (qx_needs_dequant) {
7927 // Fall back to dequant + f16 mulmat
7928 mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
7929 }
7930
7931 // Not implemented
7932 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7933
7934 const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
7935 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
7936
7937 vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
7938
7939 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7940 pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
7941 }
7942 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
7943 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
7944 const uint64_t x_ne = ggml_nelements(src0);
7945 const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
7946 const uint64_t d_ne = ggml_nelements(dst);
7947
7948 const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
7949 const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
7950 const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
7951 const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
7952 const uint64_t ids_sz = nbi2;
7953 const uint64_t d_sz = sizeof(float) * d_ne;
7954
7955 vk_pipeline to_fp16_vk_0 = nullptr;
7956 vk_pipeline to_fp16_vk_1 = nullptr;
7957 vk_pipeline to_q8_1 = nullptr;
7958
7959 if (x_non_contig) {
7960 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
7961 } else {
7962 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
7963 }
7964 if (y_non_contig) {
7965 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
7966 } else {
7967 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
7968 }
7969 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
7970 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
7971
7972 if (quantize_y) {
7973 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
7974 }
7975 vk_pipeline count_experts = ctx->device->pipeline_count_experts;
7976
7977 uint32_t expert_count_size = sizeof(uint32_t) * n_as;
7978
7979 {
7980 if (
7981 (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
7982 (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
7983 GGML_ABORT("Requested preallocation size is too large");
7984 }
7985 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
7986 ctx->prealloc_size_x = x_sz;
7987 ggml_vk_preallocate_buffers(ctx, subctx);
7988 }
7989 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
7990 ctx->prealloc_size_y = y_sz;
7991 ggml_vk_preallocate_buffers(ctx, subctx);
7992 }
7993 if (ctx->prealloc_size_split_k < expert_count_size) {
7994 ctx->prealloc_size_split_k = expert_count_size;
7995 ggml_vk_preallocate_buffers(ctx, subctx);
7996 }
7997
7998 // Request descriptor sets
7999 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
8000 if (qx_needs_dequant) {
8001 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
8002 }
8003 if (qy_needs_dequant) {
8004 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
8005 }
8006 if (quantize_y) {
8007 ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
8008 }
8009 ggml_pipeline_request_descriptor_sets(ctx, count_experts, 1);
8010 }
8011
8012 vk_buffer d_D = dst_buf_ctx->dev_buffer;
8013 const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
8014 GGML_ASSERT(d_D != nullptr);
8015 vk_buffer d_X;
8016 uint64_t x_buf_offset = 0;
8017 vk_buffer d_Y;
8018 uint64_t y_buf_offset = 0;
8019 if (!src0_uma) {
8020 d_Qx = src0_buf_ctx->dev_buffer;
8021 qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
8022 GGML_ASSERT(d_Qx != nullptr);
8023 }
8024 if (!src1_uma) {
8025 d_Qy = src1_buf_ctx->dev_buffer;
8026 qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
8027 GGML_ASSERT(d_Qy != nullptr);
8028 }
8029 if (!ids_uma) {
8030 d_ids = ids_buf_ctx->dev_buffer;
8031 ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
8032 GGML_ASSERT(d_ids != nullptr);
8033 }
8034 if (qx_needs_dequant) {
8035 d_X = ctx->prealloc_x;
8036 GGML_ASSERT(d_X->size >= x_sz);
8037 } else {
8038 d_X = d_Qx;
8039 x_buf_offset = qx_buf_offset;
8040 GGML_ASSERT(qx_sz == x_sz);
8041 }
8042 if (qy_needs_dequant) {
8043 d_Y = ctx->prealloc_y;
8044 GGML_ASSERT(d_Y->size >= y_sz);
8045 } else if (quantize_y) {
8046 d_Y = ctx->prealloc_y;
8047 GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz, 144) * 144);
8048 } else {
8049 d_Y = d_Qy;
8050 y_buf_offset = qy_buf_offset;
8051 GGML_ASSERT(qy_sz == y_sz);
8052 }
8053
8054 if (x_non_contig || qx_needs_dequant) {
8055 if (ctx->prealloc_x_need_sync) {
8056 ggml_vk_sync_buffers(ctx, subctx);
8057 }
8058 }
8059 // Count how many times each expert is used
8060 vk_subbuffer expert_count_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
8061 if (ctx->prealloc_split_k_need_sync) {
8062 ggml_vk_sync_buffers(ctx, subctx);
8063 }
8064 {
8065 const std::vector<uint32_t> pc = { (uint32_t)nei0,
8066 (uint32_t)nei1,
8067 (uint32_t)(nbi0 / ggml_type_size(ids->type)),
8068 (uint32_t)(nbi1 / ggml_type_size(ids->type)),
8069 (uint32_t)(get_misalign_bytes(ctx, ids) / ggml_type_size(ids->type)) };
8070 ggml_vk_dispatch_pipeline(ctx, subctx, count_experts,
8071 { vk_subbuffer{ d_ids, ids_buf_offset, ids_sz }, expert_count_buf }, pc, { (uint32_t)n_as, 1, 1});
8072 }
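    // The per-expert counts live in prealloc_split_k (reused here since this path never uses
    // split_k) and are consumed by the matmul_id shader through expert_count_buf, which is why
    // prealloc_split_k_need_sync is checked above and set again at the end of this function.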
8073
8074 if (x_non_contig) {
8075 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0));
8076 } else if (qx_needs_dequant) {
8077 const std::vector<uint32_t> pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) };
8078 ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0,
8079 { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_X, 0, x_sz } }, pc, { (uint32_t)x_ne, 1, 1});
8080 }
8081 if (y_non_contig) {
8082 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
8083 ctx->prealloc_y_last_tensor_used != src1) {
8084 if (ctx->prealloc_y_need_sync) {
8085 ggml_vk_sync_buffers(ctx, subctx);
8086 }
8087 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
8088 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
8089 ctx->prealloc_y_last_tensor_used = src1;
8090 }
8091 }
8092 if (quantize_y) {
8093 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
8094 ctx->prealloc_y_last_tensor_used != src1) {
8095 if (ctx->prealloc_y_need_sync) {
8096 ggml_vk_sync_buffers(ctx, subctx);
8097 }
8098 ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
8099 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
8100 ctx->prealloc_y_last_tensor_used = src1;
8101 }
8102 }
8103 ggml_vk_sync_buffers(ctx, subctx);
8104
8105 uint32_t stride_batch_x = ne00*ne01;
8106 uint32_t stride_batch_y = ne10*ne11;
8107
8108 if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
8109 stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
8110 }
8111
8112 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
8113 stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
8114 }
8115
8116 // compute
8117 ggml_vk_matmul_id(
8118 ctx, subctx, pipeline,
8119 { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
8120 { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
8121 ne01, ne21, ne10, ne10, ne10, ne01,
8122 stride_batch_x, stride_batch_y, ne20*ne21,
8123 n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
8124 ); // NOLINT
8125
8126 if (x_non_contig || qx_needs_dequant) {
8127 ctx->prealloc_x_need_sync = true;
8128 }
8129 if (y_non_contig || quantize_y) {
8130 ctx->prealloc_y_need_sync = true;
8131 }
8132 ctx->prealloc_split_k_need_sync = true;
8133}
8134
8135static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
8136 ggml_tensor * dst = cgraph->nodes[node_idx];
8137 ggml_tensor * src0 = dst->src[0];
8138 ggml_tensor * src1 = dst->src[1];
8139 ggml_tensor * ids = dst->src[2];
8140 VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
8141 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
8142 std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3];
8143 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
8144 std::cerr << "))");
8145 GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT
8146 GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT
8147 GGML_ASSERT(ids->type == GGML_TYPE_I32);
8148
8149 const uint64_t ne00 = src0->ne[0];
8150 const uint64_t ne01 = src0->ne[1];
8151 // const uint64_t ne02 = src0->ne[2];
8152 // const uint64_t ne03 = src0->ne[3];
8153
8154 const uint64_t ne10 = src1->ne[0];
8155 const uint64_t ne11 = src1->ne[1];
8156 const uint64_t ne12 = src1->ne[2];
8157 // const uint64_t ne13 = src1->ne[3];
8158
8159 const uint64_t nei0 = ids->ne[0];
8160 const uint64_t nei1 = ids->ne[1];
8161 const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
8162
8163 const uint64_t ne20 = dst->ne[0];
8164 const uint64_t ne21 = dst->ne[1];
8165 // const uint64_t ne22 = dst->ne[2];
8166 // const uint64_t ne23 = dst->ne[3];
8167
8168 const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
8169 const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
8170
8171 const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
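    // Quantize src1 to q8_1 on the fly when the device supports integer dot products,
    // src1 is contiguous f32 with an element count divisible by 4, and the mmvq heuristic considers it worthwhile.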
8172 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
8173
8174 vk_pipeline to_fp16_vk_0 = nullptr;
8175 vk_pipeline to_fp16_vk_1 = nullptr;
8176 if (x_non_contig) {
8177 to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
8178 }
8179 if (y_non_contig) {
8180 to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
8181 } else {
8182 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
8183 }
8184
8185    // Check for the mmvq (quantized mat-vec) path first
8186 vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
8187 vk_pipeline to_q8_1 = nullptr;
8188
8189 if (dmmv == nullptr) {
8190        // Fall back to the f16 dequant mul_mat_vec path
8191 dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
8192 quantize_y = false;
8193 }
8194
8195 if (quantize_y) {
8196 to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
8197 }
8198
8199 const bool qx_needs_dequant = x_non_contig;
8200 const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
8201
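    // Switch to the 64-bit indexing variant of the pipeline when src0 is larger than a single storage buffer range.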
8202 if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
8203 dmmv = ggml_vk_get_64b_indexing_pipeline(ctx, dmmv);
8204 }
8205
8206 // Not implemented
8207 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
8208 GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
8209 GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
8210 GGML_ASSERT(dmmv != nullptr);
8211
8212 const uint64_t x_ne = ggml_nelements(src0);
8213 const uint64_t y_ne = ggml_nelements(src1);
8214
8215 const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
8216 const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
8217 const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
8218 (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
8219
8220 {
8221 if (
8222 (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
8223 (qy_needs_dequant && y_sz > ctx->device->properties.limits.maxStorageBufferRange)) {
8224 GGML_ABORT("Requested preallocation size is too large");
8225 }
8226 if (qx_needs_dequant && ctx->prealloc_size_x < x_sz) {
8227 ctx->prealloc_size_x = x_sz;
8228 ggml_vk_preallocate_buffers(ctx, subctx);
8229 }
8230 if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
8231 ctx->prealloc_size_y = y_sz;
8232 ggml_vk_preallocate_buffers(ctx, subctx);
8233 }
8234
8235 // Request descriptor sets
8236 if (qx_needs_dequant) {
8237 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
8238 }
8239 if (qy_needs_dequant) {
8240 ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
8241 }
8242 if (quantize_y) {
8243 ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
8244 }
8245 ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
8246 }
8247
8248 vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
8249 vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
8250 vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
8251 vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
8252 vk_subbuffer d_F0 = d_D;
8253 vk_subbuffer d_X, d_Y;
8254
8255 if (qx_needs_dequant) {
8256 d_X = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
8257 } else {
8258 d_X = d_Qx;
8259 }
8260 if (qy_needs_dequant || quantize_y) {
8261 d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
8262 } else {
8263 d_Y = d_Qy;
8264 }
8265
8266 if (x_non_contig) {
8267 if (ctx->prealloc_x_need_sync) {
8268 ggml_vk_sync_buffers(ctx, subctx);
8269 }
8270 }
8271
8272 if (x_non_contig) {
8273 GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment));
8274 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, d_Qx, d_X);
8275 }
8276 if (y_non_contig) {
8277 GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
8278 if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
8279 ctx->prealloc_y_last_tensor_used != src1) {
8280 if (ctx->prealloc_y_need_sync) {
8281 ggml_vk_sync_buffers(ctx, subctx);
8282 }
8283 ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
8284 ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
8285 ctx->prealloc_y_last_tensor_used = src1;
8286 }
8287 }
8288 if (quantize_y) {
8289 if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
8290 ctx->prealloc_y_last_tensor_used != src1) {
8291 if (ctx->prealloc_y_need_sync) {
8292 ggml_vk_sync_buffers(ctx, subctx);
8293 }
8294 ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
8295 ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
8296 ctx->prealloc_y_last_tensor_used = src1;
8297 }
8298 }
8299
8300 uint32_t stride_batch_y = ne10*ne11;
8301
8302 if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
8303 stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
8304 }
8305
8306 const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
8307
8308 uint32_t groups_x = ne01;
8309 uint32_t groups_z = 1;
8310
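    // If there are more rows than the device allows in the X workgroup dimension,
    // spread the work across (up to) 64 workgroups in Z instead.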
8311 if (ne01 > max_groups_x) {
8312 groups_z = 64;
8313 groups_x = CEIL_DIV(groups_x, groups_z);
8314 }
8315
8316 uint32_t fusion_flags = 0;
8317
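    // When extra ops are fused after the mat-vec (a MUL or ADD_ID, optionally followed by another MUL),
    // bind their second operands as F0/F1 and tell the shader which fusion to apply via fusion_flags.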
8318 if (ctx->num_additional_fused_ops > 0) {
8319 const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
8320
8321 d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
8322
8323 if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
8324 fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
8325 } else {
8326 GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
8327 fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
8328 }
8329 }
8330
8331 vk_subbuffer d_F1 = d_D;
8332 if (ctx->num_additional_fused_ops > 1) {
8333 const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
8334
8335 d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
8336 fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
8337 }
8338
8339    // Loop over the batch dimension of the expert ids (nei1), dispatching once per entry
8340 for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
8341 const vk_mat_vec_id_push_constants pc = {
8342 (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
8343 (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
8344 fusion_flags,
8345 (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
8346 };
8347 ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
8348 {
8349 d_X,
8350 d_Y,
8351 d_D,
8352 d_F0,
8353 d_F1,
8354 d_ids,
8355 },
8356 pc, { groups_x, (uint32_t)nei0, groups_z });
8357 }
8358
8359 if (x_non_contig) {
8360 ctx->prealloc_x_need_sync = true;
8361 }
8362 if (y_non_contig || quantize_y) {
8363 ctx->prealloc_y_need_sync = true;
8364 }
8365}
8366
8367static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int node_idx) {
8368 ggml_tensor * dst = cgraph->nodes[node_idx];
8369 ggml_tensor * src0 = dst->src[0];
8370 ggml_tensor * src2 = dst->src[2];
8371 return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
8372}
8373
8374static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
8375 ggml_tensor * dst = cgraph->nodes[node_idx];
8376 ggml_tensor * src0 = dst->src[0];
8377 ggml_tensor * src1 = dst->src[1];
8378 ggml_tensor * src2 = dst->src[2];
8379 VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")");
8380 if (ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
8381 ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, cgraph, node_idx);
8382 } else {
8383 ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst);
8384 }
8385}
8386
8387static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
8388 // Needs to be kept up to date on shader changes
8389 GGML_UNUSED(hsv);
8390 const uint32_t wg_size = scalar_flash_attention_workgroup_size;
8391 const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
8392 const uint32_t Bc = scalar_flash_attention_Bc;
8393
8394 const uint32_t tmpsh = wg_size * sizeof(float);
8395 const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
8396
8397 const uint32_t masksh = Bc * Br * sizeof(float);
8398
8399 const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
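    // Rough illustration with hypothetical values (the real wg_size/Br/Bc come from the shader constants):
    // wg_size=128, Br=32, Bc=64, hsk=128 -> 512 + 2048 + 8192 + 17408 = 28160 bytes of shared memory.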
8400
8401 const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
8402 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
8403
8404 VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
8405
8406 return supported;
8407}
8408
8409static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
8410 // Needs to be kept up to date on shader changes
8411 GGML_UNUSED(hsv);
8412 const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
8413 const uint32_t Br = rows_cols[0];
8414 const uint32_t Bc = rows_cols[1];
8415
8416 const uint32_t MatBr = 16, MatBc = 16;
8417
8418 const uint32_t row_split = Bc / MatBc;
8419
8420 const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
8421
8422 const uint32_t acctype = f32acc ? 4 : 2;
8423 const uint32_t f16vec4 = 8;
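    // Sizes in bytes: acctype is one accumulator scalar (f32 or f16), f16vec4 is a packed vec4 of halves.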
8424
8425 const uint32_t qstride = hsk_pad / 4 + 2;
8426 const uint32_t Qf = Br * qstride * f16vec4;
8427
8428 const uint32_t psh_stride = Br / 4 + 2;
8429 const uint32_t Psh = Bc * psh_stride * f16vec4;
8430
8431 const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
8432 const uint32_t sfsh = Bc * sfshstride * acctype;
8433
8434 const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
8435 const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
8436 const uint32_t vsh_stride = MatBc / 4 * row_split;
8437 const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
8438
8439 const uint32_t slope = Br * acctype;
8440
8441 const uint32_t total_size = Qf + Psh + sfsh + ksh + slope;
8442 const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
8443
8444 VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
8445
8446 return supported;
8447}
8448
8449static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst) {
8450 VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
8451 std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
8452 std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
8453 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
8454 if (sinks) {
8455 std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3];
8456 }
8457 std::cerr << "))");
8458
8459 GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
8460 GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
8461 GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
8462 GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
8463 GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
8464 GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
8465 GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8466 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8467
8468 const uint32_t nem0 = mask ? mask->ne[0] : 0;
8469 const uint32_t nem1 = mask ? mask->ne[1] : 0;
8470 const uint32_t nem2 = mask ? mask->ne[2] : 0;
8471 const uint32_t nem3 = mask ? mask->ne[3] : 0;
8472
8473 const uint32_t HSK = nek0;
8474 const uint32_t HSV = nev0;
8475 uint32_t N = neq1;
8476 const uint32_t KV = nek1;
8477
8478 GGML_ASSERT(ne0 == HSV);
8479 GGML_ASSERT(ne2 == N);
8480
8481 // input tensor rows must be contiguous
8482 GGML_ASSERT(nbq0 == ggml_type_size(q->type));
8483 GGML_ASSERT(nbk0 == ggml_type_size(k->type));
8484 GGML_ASSERT(nbv0 == ggml_type_size(v->type));
8485
8486 GGML_ASSERT(neq0 == HSK);
8487
8488 GGML_ASSERT(neq1 == N);
8489
8490 GGML_ASSERT(nev1 == nek1);
8491
8492 // dst cannot be transposed or permuted
8493 GGML_ASSERT(nb0 == sizeof(float));
8494 GGML_ASSERT(nb0 <= nb1);
8495 GGML_ASSERT(nb1 <= nb2);
8496 GGML_ASSERT(nb2 <= nb3);
8497
8498 assert(dst->type == GGML_TYPE_F32);
8499 assert(q->type == GGML_TYPE_F32);
8500 assert(k->type == v->type);
8501
8502 FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
8503 ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
8504
8505 if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) {
8506 // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
8507 path = FA_SCALAR;
8508 }
8509
8510 if (path == FA_COOPMAT1) {
8511 const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
8512 (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
8513
8514 const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
8515
8516 if (!coopmat_shape_supported || !coopmat_shmem_supported) {
8517 path = FA_SCALAR;
8518 }
8519 }
8520
8521 uint32_t gqa_ratio = 1;
8522 uint32_t qk_ratio = neq2 / nek2;
8523 uint32_t workgroups_x = (uint32_t)neq1;
8524 uint32_t workgroups_y = (uint32_t)neq2;
8525 uint32_t workgroups_z = (uint32_t)neq3;
8526
8527 const bool small_cache = nek1 < 1024;
8528
8529    // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
8530 // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
8531 uint32_t max_gqa;
8532 switch (path) {
8533 case FA_SCALAR:
8534 case FA_COOPMAT1:
8535 // We may switch from coopmat1 to scalar, so use the scalar limit for both
8536 max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
8537 break;
8538 case FA_COOPMAT2:
8539 max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
8540 break;
8541 default:
8542 GGML_ASSERT(0);
8543 }
8544
8545 if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
8546 qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
8547 // grouped query attention - make the N dimension equal to gqa_ratio, reduce
8548 // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
8549 // and change addressing calculations to index Q's dimension 2.
8550 gqa_ratio = qk_ratio;
8551 N = gqa_ratio;
8552 workgroups_y /= gqa_ratio;
8553 }
8554
8555 bool small_rows = N <= get_fa_num_small_rows(path);
8556
8557 // coopmat1 does not actually support "small rows" (it needs 16 rows).
8558 // So use scalar instead.
8559 if (small_rows && path == FA_COOPMAT1) {
8560 path = FA_SCALAR;
8561 }
8562
8563 // scalar is faster than coopmat2 when N==1
8564 if (N == 1 && path == FA_COOPMAT2) {
8565 path = FA_SCALAR;
8566 }
8567
8568 // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
8569 if (path == FA_SCALAR &&
8570 !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
8571 small_rows = true;
8572 }
8573
8574 const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
8575 uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
8576 uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
8577
8578 // For F32, the shader treats it as a block of size 4 (for vec4 loads)
8579 if (k->type == GGML_TYPE_F32) {
8580 k_stride /= 4;
8581 }
8582 if (v->type == GGML_TYPE_F32) {
8583 v_stride /= 4;
8584 }
8585
8586 uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
8587 bool aligned = (KV % alignment) == 0 &&
8588 // the "aligned" shader variant will forcibly align strides, for performance
8589 (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
8590
8591 // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
8592 if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
8593 aligned = false;
8594 }
8595
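    // The scalar path always accumulates in f32; the coopmat paths honor the precision requested in op_params[3].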
8596 bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
8597
8598 float scale = 1.0f;
8599 float max_bias = 0.0f;
8600 float logit_softcap = 0.0f;
8601
8602 memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8603 memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8604 memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8605
8606 if (logit_softcap != 0) {
8607 scale /= logit_softcap;
8608 }
8609
8610 // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
8611 bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
8612
8613 uint32_t flags = (use_mask_opt ? 1 : 0) |
8614 (mask != nullptr ? 2 : 0) |
8615 (logit_softcap != 0 ? 4 : 0);
8616
8617 vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
8618
8619 vk_pipeline pipeline = nullptr;
8620
8621 {
8622 std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8623 auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
8624 auto it = pipelines.find(fa_pipeline_state);
8625 if (it != pipelines.end()) {
8626 pipeline = it->second;
8627 } else {
8628 pipelines[fa_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
8629 }
8630 }
8631
8632 assert(pipeline);
8633 // Compile early to initialize wg_denoms.
8634 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
8635
8636 uint32_t split_kv = KV;
8637 uint32_t split_k = 1;
8638
8639 // Use a placeholder core count if one isn't available. split_k is a big help for perf.
8640 const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
8641
8642 // Try to use split_k when KV is large enough to be worth the overhead.
8643    // Must either be a single batch or be using gqa; we can't mix the two.
8644 if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
8645 // Try to run two workgroups per SM.
8646 split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
8647 if (split_k > 1) {
8648            // Try to evenly split KV into split_k chunks, but each chunk must be a multiple
8649            // of the alignment, so recompute split_k from the rounded chunk size.
8650 split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
8651 split_k = CEIL_DIV(KV, split_kv);
8652 }
8653 }
8654
8655 // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
8656 // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
8657 // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
8658 // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
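    // For example (hypothetical sizes), HSV=128, ne1=4, ne2=32, ne3=1, split_k=2:
    // (128*4*4 + 4*4*2) * 2 * 32 = 133120 bytes of split_k scratch.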
8659 const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
8660 if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
8661 GGML_ABORT("Requested preallocation size is too large");
8662 }
8663 if (ctx->prealloc_size_split_k < split_k_size) {
8664 ctx->prealloc_size_split_k = split_k_size;
8665 ggml_vk_preallocate_buffers(ctx, subctx);
8666 }
8667
8668 auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
8669 const uint32_t Br = rows_cols[0];
8670 const uint32_t Bc = rows_cols[1];
8671
8672 const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
8673 const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
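    // Each metadata dword summarizes a 16*Bc-wide strip of mask columns for one Br-row block of the mask.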
8674
8675 vk_pipeline pipeline_fa_mask_opt = nullptr;
8676 if (use_mask_opt) {
8677 std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8678 auto &pipelines = ctx->device->pipeline_fa_mask_opt;
8679 auto it = pipelines.find({Br, Bc});
8680 if (it != pipelines.end()) {
8681 pipeline_fa_mask_opt = it->second;
8682 } else {
8683 pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
8684 }
8685 assert(pipeline_fa_mask_opt);
8686 ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
8687
8688 if (ctx->prealloc_size_y < mask_opt_size) {
8689 ctx->prealloc_size_y = mask_opt_size;
8690 ggml_vk_preallocate_buffers(ctx, subctx);
8691 }
8692 if (ctx->prealloc_y_need_sync) {
8693 ggml_vk_sync_buffers(ctx, subctx);
8694 }
8695 }
8696
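    // ALiBi slope bases: heads below n_head_log2 use powers of m0, the remaining heads use powers of m1;
    // the per-head slope is derived from these inside the shader.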
8697 const uint32_t n_head_kv = neq2;
8698 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
8699 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
8700 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
8701
8702 vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
8703 vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
8704 vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
8705 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
8706 vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
8707 vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
8708 vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
8709
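    // Pack the "has sinks" flag into bit 24 alongside n_head_log2 so both fit in a single push constant.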
8710 uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
8711
8712 if (use_mask_opt)
8713 {
8714 const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
8715 nem0,
8716 nem1,
8717 nem2,
8718 (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
8719 (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
8720 (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
8721 mask_opt_num_dwords,
8722 mask_opt_num_dwords * CEIL_DIV(nem1, Br),
8723 mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
8724 };
8725
8726 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
8727 { mask_buf, mask_opt_buf }, opt_pc,
8728 { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
8729 ggml_vk_sync_buffers(ctx, subctx);
8730 }
8731
8732 const vk_flash_attn_push_constants pc = { N, KV,
8733 (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
8734 (uint32_t)neq2, (uint32_t)neq3,
8735 (uint32_t)nek2, (uint32_t)nek3,
8736 (uint32_t)nev2, (uint32_t)nev3,
8737 nem1, nem2, nem3,
8738 q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
8739 k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
8740 v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
8741 scale, max_bias, logit_softcap,
8742 mask_n_head_log2, m0, m1,
8743 gqa_ratio, split_kv, split_k };
8744
8745 if (split_k > 1) {
8746 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
8747
8748 if (ctx->prealloc_split_k_need_sync) {
8749 ggml_vk_sync_buffers(ctx, subctx);
8750 }
8751 workgroups_x *= pipeline->wg_denoms[0];
8752 vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
8753 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8754 {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
8755 // We only use split_k when group query attention is enabled, which means
8756 // there's no more than one tile of rows (i.e. workgroups_x would have been
8757 // one). We reuse workgroups_x to mean the number of splits, so we need to
8758 // cancel out the divide by wg_denoms[0].
8759 pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
8760
8761 ggml_vk_sync_buffers(ctx, subctx);
8762 const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
8763 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
8764 {split_k_buf, sinks_buf, dst_buf},
8765 pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
8766 ctx->prealloc_split_k_need_sync = true;
8767 } else {
8768 if (gqa_ratio > 1) {
8769 // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
8770 workgroups_x *= pipeline->wg_denoms[0];
8771 }
8772 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8773 {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
8774 pc, { workgroups_x, workgroups_y, workgroups_z });
8775 }
8776}
8777
8778static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, uint32_t K, uint32_t NPQ) {
8779 auto n_tiles = [&](vk_conv_shapes s) {
8780 return CEIL_DIV(K, vk_conv_block_sizes[s].K)
8781 * CEIL_DIV(NPQ, vk_conv_block_sizes[s].NPQ);
8782 };
8783
8784    // We can't query the number of shader cores on Intel; use 32 as a placeholder
8785 // so small convolutions will still choose a smaller tile.
8786 const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
8787
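    // Prefer a tile shape that yields at least two tiles per shader core so the GPU stays occupied;
    // otherwise fall back to the small 64x32 tile.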
8788 if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
8789 return CONV_SHAPE_128x128;
8790 } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
8791 return CONV_SHAPE_32x256;
8792 } else {
8793 return CONV_SHAPE_64x32;
8794 }
8795}
8796
8797static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) {
8798 switch (op) {
8799 case GGML_OP_GET_ROWS:
8800 GGML_ASSERT(src1->type == GGML_TYPE_I32);
8801 if (src0->type == GGML_TYPE_I32) {
8802 // i32 src only supports i32 result
8803 GGML_ASSERT(dst->type == GGML_TYPE_I32);
8804 return ctx->device->pipeline_get_rows[src0->type];
8805 }
8806 if (dst->type == GGML_TYPE_F16) {
8807 return ctx->device->pipeline_get_rows[src0->type];
8808 }
8809 if (dst->type == GGML_TYPE_F32) {
8810 return ctx->device->pipeline_get_rows_f32[src0->type];
8811 }
8812 return nullptr;
8813 case GGML_OP_ACC:
8814 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8815 return ctx->device->pipeline_acc_f32;
8816 }
8817 return nullptr;
8818 case GGML_OP_ADD:
8819 case GGML_OP_SUB:
8820 case GGML_OP_MUL:
8821 case GGML_OP_DIV:
8822 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
8823 (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
8824 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
8825 return nullptr;
8826 }
8827 switch (op) {
8828 case GGML_OP_ADD:
8829 {
8830 if (ctx->num_additional_fused_ops > 0) {
8831 if (ctx->do_add_rms_partials) {
8832 return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops];
8833 } else {
8834 return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops];
8835 }
8836 }
8837 if (ctx->do_add_rms_partials) {
8838 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms;
8839 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8840 } else {
8841 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
8842 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8843 }
8844 }
8845 case GGML_OP_SUB:
8846 {
8847 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
8848 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8849 }
8850 case GGML_OP_MUL:
8851 {
8852 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
8853 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8854 }
8855 case GGML_OP_DIV:
8856 {
8857 auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
8858 return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
8859 }
8860 default:
8861 break;
8862 }
8863 return nullptr;
8864 case GGML_OP_ADD_ID:
8865 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
8866 return ctx->device->pipeline_add_id_f32;
8867 }
8868 return nullptr;
8869 case GGML_OP_CONCAT:
8870 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8871 return ctx->device->pipeline_concat_f32;
8872 }
8873 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
8874 return ctx->device->pipeline_concat_f16;
8875 }
8876 if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
8877 return ctx->device->pipeline_concat_i32;
8878 }
8879 return nullptr;
8880 case GGML_OP_UPSCALE:
8881 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8882 uint32_t mode = (ggml_get_op_params_i32(dst, 0) & (0xFF | GGML_SCALE_FLAG_ANTIALIAS));
8883 switch (mode) {
8884 case GGML_SCALE_MODE_NEAREST:
8885 return ctx->device->pipeline_upscale_nearest_f32;
8886 case GGML_SCALE_MODE_BILINEAR:
8887 return ctx->device->pipeline_upscale_bilinear_f32;
8888 case GGML_SCALE_MODE_BICUBIC:
8889 return ctx->device->pipeline_upscale_bicubic_f32;
8890 case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS:
8891 return ctx->device->pipeline_upscale_bilinear_antialias_f32;
8892 default:
8893 return nullptr;
8894 }
8895 }
8896 return nullptr;
8897 case GGML_OP_SCALE:
8898 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8899 return ctx->device->pipeline_scale_f32;
8900 }
8901 return nullptr;
8902 case GGML_OP_SQR:
8903 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8904 return ctx->device->pipeline_sqr_f32;
8905 }
8906 return nullptr;
8907 case GGML_OP_SQRT:
8908 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8909 return ctx->device->pipeline_sqrt_f32;
8910 }
8911 return nullptr;
8912 case GGML_OP_SIN:
8913 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8914 return ctx->device->pipeline_sin_f32;
8915 }
8916 return nullptr;
8917 case GGML_OP_COS:
8918 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8919 return ctx->device->pipeline_cos_f32;
8920 }
8921 return nullptr;
8922 case GGML_OP_LOG:
8923 if (src0->type == dst->type &&
8924 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8925 return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
8926 }
8927 return nullptr;
8928 case GGML_OP_TRI:
8929 if (src0->type == dst->type &&
8930 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8931 return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
8932 }
8933 return nullptr;
8934 case GGML_OP_DIAG:
8935 if (src0->type == dst->type &&
8936 (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8937 return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
8938 }
8939 return nullptr;
8940 case GGML_OP_CLAMP:
8941 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8942 return ctx->device->pipeline_clamp_f32;
8943 }
8944 return nullptr;
8945 case GGML_OP_PAD:
8946 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8947 return ctx->device->pipeline_pad_f32;
8948 }
8949 return nullptr;
8950 case GGML_OP_ROLL:
8951 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8952 return ctx->device->pipeline_roll_f32;
8953 }
8954 return nullptr;
8955 case GGML_OP_REPEAT:
8956 if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
8957 return ctx->device->pipeline_repeat_f32;
8958 }
8959 return nullptr;
8960 case GGML_OP_REPEAT_BACK:
8961 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8962 return ctx->device->pipeline_repeat_back_f32;
8963 }
8964 return nullptr;
8965 case GGML_OP_CPY:
8966 case GGML_OP_CONT:
8967 case GGML_OP_DUP:
8968 return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
8969 case GGML_OP_SET_ROWS:
8970 if (src1->type == GGML_TYPE_I64) {
8971 return ctx->device->pipeline_set_rows_i64[dst->type];
8972 } else {
8973 return ctx->device->pipeline_set_rows_i32[dst->type];
8974 }
8975 case GGML_OP_SILU_BACK:
8976 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8977 return ctx->device->pipeline_silu_back_f32;
8978 }
8979 return nullptr;
8980 case GGML_OP_NORM:
8981 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8982 return ctx->device->pipeline_norm_f32;
8983 }
8984 return nullptr;
8985 case GGML_OP_GROUP_NORM:
8986 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8987 return ctx->device->pipeline_group_norm_f32;
8988 }
8989 return nullptr;
8990 case GGML_OP_RMS_NORM:
8991 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
8992 if (ctx->do_add_rms_partials) {
8993 return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32;
8994 } else {
8995 return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
8996 }
8997 }
8998 return nullptr;
8999 case GGML_OP_RMS_NORM_BACK:
9000 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9001 return ctx->device->pipeline_rms_norm_back_f32;
9002 }
9003 return nullptr;
9004 case GGML_OP_L2_NORM:
9005 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9006 return ctx->device->pipeline_l2_norm_f32;
9007 }
9008 return nullptr;
9009 case GGML_OP_UNARY:
9010 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
9011 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
9012 (src0->type != dst->type)) {
9013 return nullptr;
9014 }
9015
9016 switch (ggml_get_unary_op(dst)) {
9017 case GGML_UNARY_OP_EXP:
9018 return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
9019 case GGML_UNARY_OP_SILU:
9020 return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
9021 case GGML_UNARY_OP_GELU:
9022 return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
9023 case GGML_UNARY_OP_GELU_ERF:
9024 return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16];
9025 case GGML_UNARY_OP_GELU_QUICK:
9026 return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
9027 case GGML_UNARY_OP_RELU:
9028 return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
9029 case GGML_UNARY_OP_XIELU:
9030 return ctx->device->pipeline_xielu[dst->type == GGML_TYPE_F16];
9031 case GGML_UNARY_OP_NEG:
9032 return ctx->device->pipeline_neg[dst->type == GGML_TYPE_F16];
9033 case GGML_UNARY_OP_TANH:
9034 return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
9035 case GGML_UNARY_OP_SIGMOID:
9036 return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
9037 case GGML_UNARY_OP_HARDSIGMOID:
9038 return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16];
9039 case GGML_UNARY_OP_HARDSWISH:
9040 return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16];
9041 case GGML_UNARY_OP_ABS:
9042 return ctx->device->pipeline_abs[dst->type == GGML_TYPE_F16];
9043 case GGML_UNARY_OP_SOFTPLUS:
9044 return ctx->device->pipeline_softplus[dst->type == GGML_TYPE_F16];
9045 case GGML_UNARY_OP_STEP:
9046 return ctx->device->pipeline_step[dst->type == GGML_TYPE_F16];
9047 case GGML_UNARY_OP_ROUND:
9048 return ctx->device->pipeline_round[dst->type == GGML_TYPE_F16];
9049 case GGML_UNARY_OP_CEIL:
9050 return ctx->device->pipeline_ceil[dst->type == GGML_TYPE_F16];
9051 case GGML_UNARY_OP_FLOOR:
9052 return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
9053 case GGML_UNARY_OP_TRUNC:
9054 return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
9055 default:
9056 break;
9057 }
9058 return nullptr;
9059 case GGML_OP_GLU:
9060 if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
9061 (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
9062 (src0->type != dst->type)) {
9063 return nullptr;
9064 }
9065
9066 switch (ggml_get_glu_op(dst)) {
9067 case GGML_GLU_OP_GEGLU:
9068 return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
9069 case GGML_GLU_OP_REGLU:
9070 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
9071 case GGML_GLU_OP_SWIGLU:
9072 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
9073 case GGML_GLU_OP_SWIGLU_OAI:
9074 return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
9075 case GGML_GLU_OP_GEGLU_ERF:
9076 return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
9077 case GGML_GLU_OP_GEGLU_QUICK:
9078 return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16];
9079 default:
9080 break;
9081 }
9082 return nullptr;
9083 case GGML_OP_DIAG_MASK_INF:
9084 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9085 return ctx->device->pipeline_diag_mask_inf_f32;
9086 }
9087 return nullptr;
9088 case GGML_OP_SOFT_MAX:
9089 GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
9090 GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
9091
9092 if (ctx->num_additional_fused_ops) {
9093 uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
9094 GGML_ASSERT(idx < num_topk_moe_pipelines);
9095            // use n_experts from the push constant when it is not equal to the power-of-two spec constant
9096 bool use_push = dst->ne[0] != (1u << idx);
9097 return ctx->device->pipeline_topk_moe[idx][use_push];
9098 }
9099
9100 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
9101 return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
9102 }
9103 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
9104 return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
9105 }
9106 return nullptr;
9107 case GGML_OP_SOFT_MAX_BACK:
9108 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9109 return ctx->device->pipeline_soft_max_back_f32;
9110 }
9111 return nullptr;
9112 case GGML_OP_ROPE:
9113 case GGML_OP_ROPE_BACK:
9114 {
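            // With two additional fused ops, dst is the last node of the fusion;
            // the ROPE op_params live two src links back on the original ROPE node.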
9115 const ggml_tensor *rope = ctx->num_additional_fused_ops == 2 ? dst->src[0]->src[0] : dst;
9116 const int mode = ((const int32_t *) rope->op_params)[2];
9117 const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
9118 const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
9119 const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
9120
9121 if (is_neox) {
9122 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9123 return ctx->device->pipeline_rope_neox_f32;
9124 }
9125 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9126 return ctx->device->pipeline_rope_neox_f32_f16;
9127 }
9128 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
9129 return ctx->device->pipeline_rope_neox_f16;
9130 }
9131 } else if (is_mrope && !is_vision) {
9132 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9133 return ctx->device->pipeline_rope_multi_f32;
9134 }
9135 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9136 return ctx->device->pipeline_rope_multi_f32_f16;
9137 }
9138 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
9139 return ctx->device->pipeline_rope_multi_f16;
9140 }
9141 } else if (is_vision) {
9142 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9143 return ctx->device->pipeline_rope_vision_f32;
9144 }
9145 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
9146 return ctx->device->pipeline_rope_vision_f16;
9147 }
9148 } else {
9149 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9150 return ctx->device->pipeline_rope_norm_f32;
9151 }
9152 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9153 return ctx->device->pipeline_rope_norm_f32_f16;
9154 }
9155 if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
9156 return ctx->device->pipeline_rope_norm_f16;
9157 }
9158 }
9159 return nullptr;
9160 }
9161 case GGML_OP_SUM:
9162 case GGML_OP_SUM_ROWS:
9163 case GGML_OP_MEAN:
9164 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9165 return ctx->device->pipeline_sum_rows_f32;
9166 }
9167 return nullptr;
9168 case GGML_OP_CUMSUM:
9169 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9170 if (src0->ne[0] <= 512) {
9171 return ctx->device->pipeline_cumsum_small_f32;
9172 } else {
9173 return ctx->device->pipeline_cumsum_f32;
9174 }
9175 }
9176 return nullptr;
9177 case GGML_OP_SOLVE_TRI:
9178 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9179
9180 vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
9181
9182 vk_pipeline pipeline = nullptr;
9183
9184 {
9185 std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
9186 auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
9187 if (it != ctx->device->pipeline_solve_tri_f32.end()) {
9188 pipeline = it->second;
9189 } else {
9190 ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
9191 }
9192 }
9193
9194 return pipeline;
9195 }
9196 return nullptr;
9197 case GGML_OP_ARGMAX:
9198 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
9199 return ctx->device->pipeline_argmax_f32;
9200 }
9201 return nullptr;
9202 case GGML_OP_COUNT_EQUAL:
9203 if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
9204 return ctx->device->pipeline_count_equal_i32;
9205 }
9206 return nullptr;
9207 case GGML_OP_IM2COL:
9208 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9209 return ctx->device->pipeline_im2col_f32;
9210 }
9211 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9212 return ctx->device->pipeline_im2col_f32_f16;
9213 }
9214 return nullptr;
9215 case GGML_OP_IM2COL_3D:
9216 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9217 return ctx->device->pipeline_im2col_3d_f32;
9218 }
9219 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9220 return ctx->device->pipeline_im2col_3d_f32_f16;
9221 }
9222 return nullptr;
9223 case GGML_OP_TIMESTEP_EMBEDDING:
9224 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9225 return ctx->device->pipeline_timestep_embedding_f32;
9226 }
9227 return nullptr;
9228 case GGML_OP_CONV_TRANSPOSE_1D:
9229 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9230 return ctx->device->pipeline_conv_transpose_1d_f32;
9231 }
9232 return nullptr;
9233 case GGML_OP_POOL_2D:
9234 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9235 return ctx->device->pipeline_pool2d_f32;
9236 }
9237 return nullptr;
9238 case GGML_OP_RWKV_WKV6:
9239 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9240 return ctx->device->pipeline_rwkv_wkv6_f32;
9241 }
9242 return nullptr;
9243 case GGML_OP_RWKV_WKV7:
9244 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9245 return ctx->device->pipeline_rwkv_wkv7_f32;
9246 }
9247 return nullptr;
9248 case GGML_OP_SSM_SCAN:
9249 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9250 const uint32_t d_state = src0->ne[0];
9251 if (d_state == 128) {
9252 return ctx->device->pipeline_ssm_scan_f32_d128;
9253 } else if (d_state == 256) {
9254 return ctx->device->pipeline_ssm_scan_f32_d256;
9255 }
9256 }
9257 return nullptr;
9258 case GGML_OP_SSM_CONV:
9259 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9260 return ctx->device->pipeline_ssm_conv_f32;
9261 }
9262 return nullptr;
9263 case GGML_OP_OPT_STEP_ADAMW:
9264 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9265 return ctx->device->pipeline_opt_step_adamw_f32;
9266 }
9267 return nullptr;
9268 case GGML_OP_OPT_STEP_SGD:
9269 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9270 return ctx->device->pipeline_opt_step_sgd_f32;
9271 }
9272 return nullptr;
9273 case GGML_OP_LEAKY_RELU:
9274 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9275 return ctx->device->pipeline_leaky_relu_f32;
9276 }
9277 return nullptr;
9278 case GGML_OP_CONV_2D:
9279 case GGML_OP_CONV_TRANSPOSE_2D:
9280 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9281 uint32_t K = dst->ne[2]; // Cout
9282 uint32_t NPQ = dst->ne[3] * dst->ne[1] * dst->ne[0]; // N * OH * OW
9283 vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, K, NPQ);
9284
9285 bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
9286 uint32_t KW = (uint32_t)src0->ne[0];
9287 uint32_t KH = (uint32_t)src0->ne[1];
9288 uint32_t s0 = (uint32_t)(ggml_get_op_params_i32(dst, 0));
9289 uint32_t s1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 1) : s0;
9290 uint32_t p0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 2) : 0;
9291 uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
9292 uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
9293 uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
9294 vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
9295
9296 std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
9297 if (op == GGML_OP_CONV_2D) {
9298 if (src0->type == GGML_TYPE_F32) {
9299 pipelines = &ctx->device->pipeline_conv2d_f32[shape];
9300 } else if (src0->type == GGML_TYPE_F16) {
9301 pipelines = &ctx->device->pipeline_conv2d_f16_f32[shape];
9302 }
9303 } else if (op == GGML_OP_CONV_TRANSPOSE_2D) {
9304 if (src0->type == GGML_TYPE_F32) {
9305 pipelines = &ctx->device->pipeline_conv_transpose_2d_f32[shape];
9306 } else if (src0->type == GGML_TYPE_F16) {
9307 pipelines = &ctx->device->pipeline_conv_transpose_2d_f16_f32[shape];
9308 }
9309 }
9310
9311 vk_pipeline pipeline = nullptr;
9312
9313 {
9314 std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
9315 auto it = pipelines->find(conv2d_pipeline_state);
9316 if (it != pipelines->end()) {
9317 pipeline = it->second;
9318 } else {
9319 (*pipelines)[conv2d_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
9320 }
9321 }
9322
9323 return pipeline;
9324 }
9325 return nullptr;
9326 case GGML_OP_CONV_2D_DW:
9327 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9328 if (ggml_is_contiguous(src1)) {
9329 return ctx->device->pipeline_conv2d_dw_whcn_f32;
9330 } else if (ggml_is_contiguous_channels(src1)) {
9331 return ctx->device->pipeline_conv2d_dw_cwhn_f32;
9332 }
9333 } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
9334 if (ggml_is_contiguous(src1)) {
9335 return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
9336 } else if (ggml_is_contiguous_channels(src1)) {
9337 return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
9338 }
9339 }
9340 return nullptr;
9341 case GGML_OP_ADD1:
9342 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
9343 return ctx->device->pipeline_add1_f16_f16;
9344 }
9345 if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
9346 return ctx->device->pipeline_add1_f16_f32;
9347 }
9348 if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9349 return ctx->device->pipeline_add1_f32_f32;
9350 }
9351 return nullptr;
9352 case GGML_OP_ARANGE:
9353 if (dst->type == GGML_TYPE_F32) {
9354 return ctx->device->pipeline_arange_f32;
9355 }
9356 return nullptr;
9357 case GGML_OP_FILL:
9358 if (dst->type == GGML_TYPE_F32) {
9359 return ctx->device->pipeline_fill_f32;
9360 }
9361 return nullptr;
9362 default:
9363 return nullptr;
9364 }
9365
9366 GGML_UNUSED(src2);
9367}
9368
9369template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9370 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
9371 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9372
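    // Offsets are in elements of the respective tensor: src0 in the high 16 bits, dst in the low 16 bits.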
9373 p.misalign_offsets = (a_offset << 16) | d_offset;
9374
9375 GGML_UNUSED(src1);
9376 GGML_UNUSED(src2);
9377 GGML_UNUSED(src3);
9378}
9379
9380template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9381 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
9382 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9383
9384 p.misalign_offsets = (a_offset << 16) | d_offset;
9385
9386 GGML_UNUSED(src1);
9387 GGML_UNUSED(src2);
9388 GGML_UNUSED(src3);
9389}
9390
9391template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9392 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
9393 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9394
9395 p.misalign_offsets = (a_offset << 16) | d_offset;
9396
9397 GGML_UNUSED(src1);
9398 GGML_UNUSED(src2);
9399 GGML_UNUSED(src3);
9400}
9401
9402template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9403 const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
9404 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9405
9406 p.misalign_offsets = (a_offset << 16) | d_offset;
9407
9408 GGML_UNUSED(src0);
9409 GGML_UNUSED(src2);
9410 GGML_UNUSED(src3);
9411}
9412
9413template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9414 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
9415 const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
9416 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9417
9418 GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
9419
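    // Pack three element offsets into one dword: src0 in bits 16..31, src1 in bits 8..15, dst in bits 0..7.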
9420 p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
9421
9422 GGML_UNUSED(src2);
9423 GGML_UNUSED(src3);
9424}
9425
9426template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
9427 const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
9428 const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
9429
9430 p.a_offset = a_offset;
9431 p.d_offset = d_offset;
9432
9433 GGML_UNUSED(src1);
9434 GGML_UNUSED(src2);
9435 GGML_UNUSED(src3);
9436}
9437
9438template<typename PC>
9439static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
9440 VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
9441 if (src1 != nullptr) {
9442 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
9443 }
9444 if (src2 != nullptr) {
9445 std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
9446 }
9447 if (src3 != nullptr) {
9448 std::cerr << "), (" << src3 << ", name=" << src3->name << ", type=" << src3->type << ", ne0=" << src3->ne[0] << ", ne1=" << src3->ne[1] << ", ne2=" << src3->ne[2] << ", ne3=" << src3->ne[3] << ", nb0=" << src3->nb[0] << ", nb1=" << src3->nb[1] << ", nb2=" << src3->nb[2] << ", nb3=" << src3->nb[3];
9449 }
9450 std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
9451 std::cerr << "), " << ggml_op_name(op) << ")");
9452 GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
9453 GGML_ASSERT(dst->buffer != nullptr);
9454 const uint64_t ne00 = src0->ne[0];
9455 const uint64_t ne01 = src0->ne[1];
9456 const uint64_t ne02 = src0->ne[2];
9457 const uint64_t ne03 = src0->ne[3];
9458
9459 const bool use_src1 = src1 != nullptr;
9460 const uint64_t ne10 = use_src1 ? src1->ne[0] : 0;
9461 const uint64_t ne11 = use_src1 ? src1->ne[1] : 0;
9462 const uint64_t ne12 = use_src1 ? src1->ne[2] : 0;
9463 const uint64_t ne13 = use_src1 ? src1->ne[3] : 0;
9464
9465 const bool use_src2 = src2 != nullptr;
9466 const bool use_src3 = src3 != nullptr;
9467
9468 init_pushconst_fastdiv(pc);
9469
9470 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
9471
9472 if (pipeline == nullptr) {
9473 std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type);
9474 if (src1 != nullptr) {
9475 std::cerr << " and " << ggml_type_name(src1->type);
9476 }
9477 std::cerr << " to " << ggml_type_name(dst->type) << std::endl;
9478 GGML_ABORT("fatal error");
9479 }
9480
9481 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9482
9483 vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
9484 vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
9485 vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
9486 vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
9487 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
9488
9489     // Compute misalignment offsets for descriptors and store them in the push constants.
9490 init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
9491
9492 std::array<uint32_t, 3> elements;
9493
9494 switch (op) {
9495 case GGML_OP_NORM:
9496 case GGML_OP_RMS_NORM_BACK:
9497 case GGML_OP_L2_NORM:
9498 case GGML_OP_SOFT_MAX:
9499 case GGML_OP_SOFT_MAX_BACK:
9500 case GGML_OP_SUM_ROWS:
9501 case GGML_OP_CUMSUM:
9502 case GGML_OP_MEAN:
9503 case GGML_OP_ARGMAX:
9504 {
9505 const uint32_t nr = ggml_nrows(src0);
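                // Fold large row counts into a 2D (or 3D) grid of at most 512 per
                // dimension, presumably to stay within the device's per-dimension
                // workgroup-count limits; 262144 == 512 * 512.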
9506 if (nr > 262144) {
9507 elements = { 512, 512, CEIL_DIV(nr, 262144) };
9508 } else if (nr > 512) {
9509 elements = { 512, CEIL_DIV(nr, 512), 1 };
9510 } else {
9511 elements = { nr, 1, 1 };
9512 }
9513 } break;
9514 case GGML_OP_SOLVE_TRI:
9515 {
9516 uint32_t nr = (uint32_t)(ne02 * ne03);
9517 if (nr > 262144) {
9518 elements = { 512, 512, CEIL_DIV(nr, 262144) };
9519 } else if (nr > 512) {
9520 elements = { 512, CEIL_DIV(nr, 512), 1 };
9521 } else {
9522 elements = { nr, 1, 1 };
9523 }
9524 }
9525 break;
9526 case GGML_OP_RMS_NORM:
9527 if (ctx->do_add_rms_partials) {
9528 // Run one element per thread, 128 threads per workgroup
9529 elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 };
9530 } else {
9531 elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
9532 }
9533 break;
9534
9535 case GGML_OP_SUM:
9536 // We use GGML_OP_SUM_ROWS with 1 row.
9537 elements = { 1, 1, 1 };
9538 break;
9539 case GGML_OP_GROUP_NORM:
9540 {
9541 const uint32_t num_groups = dst->op_params[0];
9542 elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 };
9543 } break;
9544 case GGML_OP_DIAG_MASK_INF:
9545 elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
9546 break;
9547 case GGML_OP_ROPE:
9548 case GGML_OP_ROPE_BACK:
9549 {
9550 uint32_t nrows = (uint32_t)ggml_nrows(src0);
9551 uint32_t z = 1;
9552 if (nrows > ctx->device->properties.limits.maxComputeWorkGroupCount[0]) {
9553 z = CEIL_DIV(nrows, 32768);
9554 nrows = 32768;
9555 }
9556 elements = { nrows, (uint32_t)ne00, z };
9557
9558 } break;
9559 case GGML_OP_GET_ROWS:
9560 elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
9561 elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9562 elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9563 break;
9564 case GGML_OP_ARGSORT:
9565 GGML_ASSERT(0);
9566 break;
9567 case GGML_OP_IM2COL:
9568 {
9569 const bool is_2D = dst->op_params[6] == 1;
9570
9571 const uint32_t IC = src1->ne[is_2D ? 2 : 1];
9572
9573 const uint32_t KH = is_2D ? src0->ne[1] : 1;
9574 const uint32_t KW = src0->ne[0];
9575
9576 const uint32_t OH = is_2D ? dst->ne[2] : 1;
9577 const uint32_t OW = dst->ne[1];
9578
9579 const uint32_t batch = src1->ne[is_2D ? 3 : 2];
9580
9581 elements = { OW * KW * KH, OH, batch * IC };
9582 elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9583 elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9584 } break;
9585 case GGML_OP_IM2COL_3D:
9586 {
9587 const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
9588
9589 const uint32_t N = ne13 / IC;
9590
9591 const uint32_t KD = ne02;
9592 const uint32_t KH = ne01;
9593 const uint32_t KW = ne00;
9594
9595 const uint32_t OD = dst->ne[3] / N;
9596 const uint32_t OH = dst->ne[2];
9597 const uint32_t OW = dst->ne[1];
9598
9599 const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
9600 const uint32_t N_OD_OH = N*OD*OH;
9601
9602 elements = { IC_KD_KH_KW, OW, N_OD_OH };
9603 elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9604 } break;
9605 case GGML_OP_TIMESTEP_EMBEDDING:
9606 {
9607 const uint32_t dim = dst->op_params[0];
9608 uint32_t half_ceil = (dim + 1) / 2;
9609 elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
9610 } break;
9611 case GGML_OP_CONV_TRANSPOSE_1D:
9612 {
9613 elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
9614 } break;
9615 case GGML_OP_POOL_2D:
9616 {
9617 const uint32_t N = dst->ne[3];
9618 const uint32_t OC = dst->ne[2];
9619 const uint32_t OH = dst->ne[1];
9620 const uint32_t OW = dst->ne[0];
9621 elements = { N * OC * OH * OW, 1, 1};
9622 } break;
9623 case GGML_OP_CONV_2D:
9624 case GGML_OP_CONV_TRANSPOSE_2D:
9625 if constexpr (std::is_same_v<PC, vk_op_conv2d_push_constants>) {
9626 const uint32_t NPQ = pc.N * pc.OH * pc.OW;
9627 const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.Cout, NPQ);
9628 const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ);
9629
9630 elements = { pc.Cout, NPQ_blocks, 1 };
9631 if (elements[1] > 512) {
9632 elements[2] = CEIL_DIV(elements[1], 512);
9633 elements[1] = 512;
9634 }
9635 } else {
9636 GGML_ABORT("invalid push constant type for CONV_2D");
9637 }
9638 break;
9639 case GGML_OP_ADD:
9640 case GGML_OP_SUB:
9641 case GGML_OP_DIV:
9642 case GGML_OP_MUL:
9643 case GGML_OP_ADD1:
9644 case GGML_OP_ARANGE:
9645 case GGML_OP_FILL:
9646 case GGML_OP_SCALE:
9647 case GGML_OP_SQR:
9648 case GGML_OP_SQRT:
9649 case GGML_OP_SIN:
9650 case GGML_OP_COS:
9651 case GGML_OP_LOG:
9652 case GGML_OP_TRI:
9653 case GGML_OP_DIAG:
9654 case GGML_OP_CLAMP:
9655 case GGML_OP_PAD:
9656 case GGML_OP_ROLL:
9657 case GGML_OP_REPEAT:
9658 case GGML_OP_REPEAT_BACK:
9659 case GGML_OP_CPY:
9660 case GGML_OP_CONCAT:
9661 case GGML_OP_UPSCALE:
9662 case GGML_OP_UNARY:
9663 case GGML_OP_GLU:
9664 case GGML_OP_CONV_2D_DW:
9665 {
9666 uint32_t ne = ggml_nelements(dst);
9667 if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
9668 // Convert from number of logical elements to 2- or 4-byte units.
9669 ne /= ggml_blck_size(src0->type);
9670 if ((ggml_type_size(src0->type) % 4) == 0) {
9671 ne *= ggml_type_size(src0->type) / 4;
9672 } else {
9673 ne *= ggml_type_size(src0->type) / 2;
9674 }
9675 }
9676 // copy_to_quant has block size of 32, and each thread does QUANT_K elements.
9677 // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements.
9678 // So divide by block size here before splitting into 512x512 groups.
9679 if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
9680 ne = CEIL_DIV(ne, ggml_blck_size(dst->type));
9681 }
9682 if (ne > 262144) {
9683 elements = { 512, 512, CEIL_DIV(ne, 262144) };
9684 } else if (ne > 512) {
9685 elements = { 512, CEIL_DIV(ne, 512), 1 };
9686 } else {
9687 elements = { ne, 1, 1 };
9688 }
9689
9690 if (pipeline == ctx->device->pipeline_cpy_transpose_32 ||
9691 pipeline == ctx->device->pipeline_cpy_transpose_16) {
9692 // 32x32 tiles
9693 elements[0] = (uint32_t)CEIL_DIV(dst->ne[0], 32);
9694 elements[1] = (uint32_t)CEIL_DIV(dst->ne[1], 32);
9695 elements[2] = (uint32_t)(dst->ne[2]*dst->ne[3]);
9696 elements[0] = std::min(elements[0], ctx->device->properties.limits.maxComputeWorkGroupCount[0]);
9697 elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9698 elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9699 }
9700 } break;
9701 case GGML_OP_ADD_ID:
9702 {
9703 elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
9704 } break;
9705 case GGML_OP_SET_ROWS:
9706 {
9707 uint32_t ne = ggml_nelements(src0);
9708 if (ggml_is_quantized(dst->type)) {
9709 // quants run 32 threads each doing QUANT_K elements
9710 ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type));
9711 } else {
9712 // scalar types do one element per thread, running 512 threads
9713 ne = CEIL_DIV(ne, 512);
9714 }
9715 if (ne > 262144) {
9716 elements = { 512, 512, CEIL_DIV(ne, 262144) };
9717 } else if (ne > 512) {
9718 elements = { 512, CEIL_DIV(ne, 512), 1 };
9719 } else {
9720 elements = { ne, 1, 1 };
9721 }
9722 }
9723 break;
9724 case GGML_OP_SSM_CONV:
9725 {
9726 const uint32_t nr = src0->ne[1];
9727 const uint32_t n_t = dst->ne[1];
9728 const uint32_t n_s = dst->ne[2];
9729 elements = { nr, n_t, n_s };
9730 }
9731 break;
9732 default:
9733 elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
9734 break;
9735 }
9736
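    // Descriptor binding order below must match each shader's layout; ops with
    // optional inputs rebind src0 into the unused slots so every binding is valid.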
9737 if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
9738 vk_subbuffer a_buf = src0_buf;
9739 if (ctx->do_add_rms_partials) {
9740 a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
9741 }
9742 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9743 { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
9744 } else if (op == GGML_OP_GLU) {
9745 // Empty src1 is possible in glu, but the shader needs a buffer
9746 vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
9747 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
9748 } else if (op == GGML_OP_SOFT_MAX) {
9749         // Empty src1 and src2 are possible in soft_max, but the shader needs a buffer
9750 vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
9751 vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
9752 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
9753 } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
9754         // Empty src2 and src3 are possible in rope, but the shader needs a buffer
9755 vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
9756 vk_subbuffer subbuf3 = use_src3 ? src3_buf : src0_buf;
9757 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf, subbuf3 }, pc, elements);
9758 } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
9759 if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
9760 // buffer device address path doesn't use dst buffer
9761 dst_buf.size = 1;
9762 }
9763 // im2col uses only src1 and dst buffers
9764 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
9765 } else if (op == GGML_OP_COUNT_EQUAL) {
9766         // count_equal assumes that the destination buffer is initialized with zeroes
9767 ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
9768 ggml_vk_sync_buffers(ctx, subctx);
9769 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
9770 } else if (op == GGML_OP_OPT_STEP_SGD) {
9771         // OPT_STEP_SGD works on src0, so it does not need dst
9772 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
9773 } else if (use_src3) {
9774 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, src3_buf, dst_buf }, pc, elements);
9775 } else if (use_src2) {
9776 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
9777 } else if (use_src1) {
9778 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
9779 } else {
9780 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
9781 }
9782}
9783
9784static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9785 const uint32_t src0_type_size = ggml_type_size(src0->type);
9786 const uint32_t src1_type_size = ggml_type_size(src1->type);
9787 const uint32_t dst_type_size = ggml_type_size(dst->type);
9788
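    // vk_op_binary_push_constants carries the element count, then ne/nb (in elements)
    // for src0, src1 and dst, followed by the misalignment word (filled in later by
    // init_pushconst_tensor_offsets) and three per-op parameters, unused here.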
9789 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS, {
9790 (uint32_t)ggml_nelements(src0),
9791 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
9792 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9793 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
9794 0,
9795 0.0f, 0.0f, 0,
9796 });
9797}
9798
9799static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9800 const uint32_t src0_type_size = ggml_type_size(src0->type);
9801 const uint32_t src1_type_size = ggml_type_size(src1->type);
9802 const uint32_t dst_type_size = ggml_type_size(dst->type);
9803
9804 int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
9805 int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
9806 // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
9807     int offset = dst->op_params[3] / 4; // offset converted from bytes to float elements
9808
9809 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
9810 (uint32_t)ggml_nelements(src0),
9811 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
9812 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9813 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
9814 0,
9815 0.0f, 0.0f, offset,
9816 });
9817}
9818
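// Dispatch one fused kernel for a run of consecutive ADD nodes (optionally feeding
// RMS-norm partial sums). All sources, the destination and the optional partials
// buffer must fit within MAX_PARAMETER_COUNT descriptors.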
9819static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
9820 const ggml_tensor *first_node = cgraph->nodes[node_idx];
9821 const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
9822
9823 // Make a list of all the tensors used by the op.
9824 // Last element of the list is the dest tensor.
9825 const ggml_tensor *tensors[MAX_PARAMETER_COUNT];
9826 uint32_t num_srcs = ctx->num_additional_fused_ops + 2;
9827 uint32_t num_tensors = num_srcs + 1;
9828 GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT);
9829
9830 tensors[0] = first_node->src[0];
9831 tensors[1] = first_node->src[1];
9832 for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) {
9833 // check whether the previous result is src[0] or src[1]
9834 if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) {
9835 tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1];
9836 } else {
9837 tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0];
9838 }
9839 }
9840 tensors[num_srcs] = dst;
9841
9842 vk_op_multi_add_push_constants pc;
9843 pc.ne20 = (uint32_t)dst->ne[0];
9844 pc.ne21 = (uint32_t)dst->ne[1];
9845 pc.ne22 = (uint32_t)dst->ne[2];
9846 pc.ne23 = (uint32_t)dst->ne[3];
9847
9848 for (uint32_t i = 0; i < num_tensors; ++i) {
9849 const ggml_tensor *t = tensors[i];
9850 pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float);
9851 pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float);
9852 pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float);
9853 pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float);
9854 }
9855 pc.rms_partials = ctx->do_add_rms_partials;
9856
9857 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op);
9858
9859 if (pipeline == nullptr) {
9860 std::cerr << "ggml_vulkan: Error: Missing multi_add";
9861 GGML_ABORT("fatal error");
9862 }
9863
9864 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9865
9866 ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT];
9867 vk_buffer buf[MAX_PARAMETER_COUNT];
9868 size_t offset[MAX_PARAMETER_COUNT];
9869 bool uma[MAX_PARAMETER_COUNT];
9870
9871 for (uint32_t i = 0; i < num_tensors; ++i) {
9872 buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
9873 buf[i] = nullptr;
9874 offset[i] = 0;
9875 uma[i] = false;
9876
9877 if (ctx->device->uma) {
9878 ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
9879 uma[i] = buf[i] != nullptr;
9880 }
9881 if (!uma[i]) {
9882 buf[i] = buf_ctx[i]->dev_buffer;
9883 offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
9884 }
9885 GGML_ASSERT(buf[i] != nullptr);
9886 }
9887 // If any remaining descriptors are unused, just point them at src[0]
9888 for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) {
9889 buf[i] = buf[0];
9890 offset[i] = 0;
9891 }
9892 if (ctx->do_add_rms_partials) {
9893 buf[num_tensors] = ctx->prealloc_add_rms_partials;
9894 offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset;
9895 }
9896
9897 std::array<uint32_t, 3> elements;
9898
9899 uint32_t ne = ggml_nelements(dst);
9900 if (ne > 262144) {
9901 elements = { 512, 512, CEIL_DIV(ne, 262144) };
9902 } else if (ne > 512) {
9903 elements = { 512, CEIL_DIV(ne, 512), 1 };
9904 } else {
9905 elements = { ne, 1, 1 };
9906 }
9907
9908 static_assert(MAX_PARAMETER_COUNT == 12);
9909 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9910 {
9911 ggml_vk_subbuffer(ctx, buf[0], offset[0]),
9912 ggml_vk_subbuffer(ctx, buf[1], offset[1]),
9913 ggml_vk_subbuffer(ctx, buf[2], offset[2]),
9914 ggml_vk_subbuffer(ctx, buf[3], offset[3]),
9915 ggml_vk_subbuffer(ctx, buf[4], offset[4]),
9916 ggml_vk_subbuffer(ctx, buf[5], offset[5]),
9917 ggml_vk_subbuffer(ctx, buf[6], offset[6]),
9918 ggml_vk_subbuffer(ctx, buf[7], offset[7]),
9919 ggml_vk_subbuffer(ctx, buf[8], offset[8]),
9920 ggml_vk_subbuffer(ctx, buf[9], offset[9]),
9921 ggml_vk_subbuffer(ctx, buf[10], offset[10]),
9922 ggml_vk_subbuffer(ctx, buf[11], offset[11]),
9923 }, pc, elements);
9924}
9925
9926static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9927 const uint32_t src0_type_size = ggml_type_size(src0->type);
9928 const uint32_t src1_type_size = ggml_type_size(src1->type);
9929 const uint32_t dst_type_size = ggml_type_size(dst->type);
9930
9931 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD, {
9932 (uint32_t)ggml_nelements(src0),
9933 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
9934 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9935 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
9936 0,
9937 0.0f, 0.0f, ctx->do_add_rms_partials,
9938 });
9939}
9940
9941static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9942 const uint32_t src0_type_size = ggml_type_size(src0->type);
9943 const uint32_t src1_type_size = ggml_type_size(src1->type);
9944 const uint32_t dst_type_size = ggml_type_size(dst->type);
9945
9946 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SUB, {
9947 (uint32_t)ggml_nelements(src0),
9948 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
9949 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9950 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
9951 0,
9952 0.0f, 0.0f, 0,
9953 });
9954}
9955
9956static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9957 const uint32_t src0_type_size = ggml_type_size(src0->type);
9958 const uint32_t src1_type_size = ggml_type_size(src1->type);
9959 const uint32_t dst_type_size = ggml_type_size(dst->type);
9960
9961 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_MUL, {
9962 (uint32_t)ggml_nelements(src0),
9963 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
9964 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9965 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
9966 0,
9967 0.0f, 0.0f, 0,
9968 });
9969}
9970
9971static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
9972 const uint32_t src0_type_size = ggml_type_size(src0->type);
9973 const uint32_t src1_type_size = ggml_type_size(src1->type);
9974 const uint32_t dst_type_size = ggml_type_size(dst->type);
9975
9976 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_DIV, {
9977 (uint32_t)ggml_nelements(src0),
9978 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
9979 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9980 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
9981 0,
9982 0.0f, 0.0f, 0,
9983 });
9984}
9985
9986static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
9987 const uint32_t src0_type_size = ggml_type_size(src0->type);
9988 const uint32_t src1_type_size = ggml_type_size(src1->type);
9989 const uint32_t src2_type_size = ggml_type_size(src2->type);
9990
9991 ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_ADD_ID, {
9992 (uint32_t)dst->ne[0],
9993 (uint32_t)dst->ne[1],
9994 (uint32_t)src0->nb[1] / src0_type_size,
9995 (uint32_t)src0->nb[2] / src0_type_size,
9996 (uint32_t)src1->nb[1] / src1_type_size,
9997 (uint32_t)src2->nb[1] / src2_type_size,
9998 });
9999}
10000
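// Common dispatch path for the RWKV WKV kernels: versions 6 and 7 share the same
// push constants and grid; version 7 only binds one extra source buffer.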
10001static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version) {
10002 GGML_ASSERT(version == 6 || version == 7);
10003 int num_srcs = version == 6 ? 6 : 7;
10004
10005 for (int i = 0; i < num_srcs; i++) {
10006 GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
10007 }
10008
10009 GGML_ASSERT(dst->buffer != nullptr);
10010
10011 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
10012 GGML_ASSERT(pipeline != nullptr);
10013
10014 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10015
10016 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10017 vk_subbuffer src_buf[7] = {};
10018 for (int i = 0; i < num_srcs; i++) {
10019 src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
10020 }
10021
10022 std::array<uint32_t, 3> elements = {
10023 (uint32_t)(pc.B * pc.H),
10024 1,
10025 1
10026 };
10027
10028 if (version == 6) {
10029 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10030 {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
10031 pc, elements);
10032 } else if (version == 7) {
10033 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10034 {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
10035 pc, elements);
10036 } else {
10037 // shouldn't happen
10038 GGML_ASSERT(false);
10039 }
10040}
10041
10042static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10043 const size_t seq_length = dst->src[0]->ne[2];
10044 const size_t n_embed = dst->ne[0];
10045 const size_t n_heads = dst->src[0]->ne[1];
10046 const size_t n_seqs = dst->src[5]->ne[1];
10047
10048 ggml_vk_op_f32_wkv(
10049 ctx, subctx, dst,
10050 {
10051 (uint32_t)n_seqs,
10052 (uint32_t)seq_length,
10053 (uint32_t)n_embed,
10054 (uint32_t)n_heads,
10055 },
10056 6
10057 );
10058}
10059
10060static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10061 const size_t seq_length = dst->src[0]->ne[2];
10062 const size_t n_embed = dst->ne[0];
10063 const size_t n_heads = dst->src[0]->ne[1];
10064 const size_t n_seqs = dst->src[6]->ne[1];
10065
10066 ggml_vk_op_f32_wkv(
10067 ctx, subctx, dst,
10068 {
10069 (uint32_t)n_seqs,
10070 (uint32_t)seq_length,
10071 (uint32_t)n_embed,
10072 (uint32_t)n_heads,
10073 },
10074 7
10075 );
10076}
10077
10078static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10079 const ggml_tensor * src0 = dst->src[0];
10080 const ggml_tensor * src1 = dst->src[1];
10081 const ggml_tensor * src2 = dst->src[2];
10082 const ggml_tensor * src3 = dst->src[3];
10083 const ggml_tensor * src4 = dst->src[4];
10084 const ggml_tensor * src5 = dst->src[5];
10085
10086 GGML_ASSERT(dst->buffer != nullptr);
10087
10088 const uint32_t head_dim = src0->ne[1];
10089 const uint32_t n_head = src1->ne[1];
10090 const uint32_t n_group = src4->ne[1];
10091 const uint32_t n_tok = src1->ne[2];
10092 const uint32_t n_seq = src1->ne[3];
10093
10094 bool is_mamba2 = (src3->nb[1] == sizeof(float));
10095 GGML_ASSERT(is_mamba2);
10096
10097 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
10098 GGML_ASSERT(pipeline != nullptr);
10099
10100 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10101
10102 const int64_t s_off = ggml_nelements(src1) * sizeof(float);
10103
10104 const vk_op_ssm_scan_push_constants pc = {
10105 (uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
10106 (uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
10107 (uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
10108 (uint32_t)src3->nb[1],
10109 (uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
10110 (uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
10111 (uint32_t)s_off,
10112 n_head, head_dim, n_group, n_tok
10113 };
10114
10115 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10116 vk_subbuffer src_buf[7] = {};
10117 for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
10118 src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
10119 }
10120
10121 std::array<uint32_t, 3> elements;
10122
10123 const uint32_t d_state = src0->ne[0];
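    // d_state is assumed to be a multiple of the device subgroup size here; the grid
    // covers the n_head * head_dim rows along x and one sequence per y workgroup.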
10124 uint32_t num_subgroups = d_state / ctx->device->subgroup_size;
10125 const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups);
10126 const uint32_t num_workgroups_y = n_seq;
10127 elements = { num_workgroups_x, num_workgroups_y, 1 };
10128
10129 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10130 {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], src_buf[6], dst_buf},
10131 pc, elements);
10132}
10133
10134static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10135 const ggml_tensor * src0 = dst->src[0];
10136 const ggml_tensor * src1 = dst->src[1];
10137
10138 ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
10139 (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
10140 (uint32_t)src1->nb[1],
10141 (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
10142 (uint32_t)src1->ne[0],
10143 (uint32_t)src0->ne[0],
10144 (uint32_t)src0->ne[1],
10145 (uint32_t)dst->ne[1],
10146 (uint32_t)dst->ne[2],
10147 });
10148}
10149
10150static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc) {
10151 const ggml_tensor * x = dst->src[0];
10152 const ggml_tensor * g = dst->src[1];
10153 const ggml_tensor * gm = dst->src[2];
10154 const ggml_tensor * gv = dst->src[3];
10155 const ggml_tensor * p = dst->src[4];
10156
10157 GGML_ASSERT(x->type == GGML_TYPE_F32);
10158 GGML_ASSERT(g->type == GGML_TYPE_F32);
10159 GGML_ASSERT(gm->type == GGML_TYPE_F32);
10160 GGML_ASSERT(gv->type == GGML_TYPE_F32);
10161 GGML_ASSERT(p->type == GGML_TYPE_F32);
10162 GGML_ASSERT(dst->buffer != nullptr);
10163 GGML_ASSERT(ggml_is_contiguous(x));
10164 GGML_ASSERT(ggml_is_contiguous(g));
10165 GGML_ASSERT(ggml_is_contiguous(gm));
10166 GGML_ASSERT(ggml_is_contiguous(gv));
10167 GGML_ASSERT(ggml_is_contiguous(p));
10168 GGML_ASSERT(ggml_are_same_shape(x, g));
10169 GGML_ASSERT(ggml_are_same_shape(x, gm));
10170 GGML_ASSERT(ggml_are_same_shape(x, gv));
10171 GGML_ASSERT(ggml_nelements(p) == 7);
10172
10173 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
10174 GGML_ASSERT(pipeline != nullptr);
10175
10176 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10177
10178 vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
10179 vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
10180 vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
10181 vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
10182 vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);
10183
10184 std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
10185
10186 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10187 {x_buf, g_buf, gm_buf, gv_buf, p_buf},
10188 pc, elements);
10189}
10190
10191static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10192 const size_t n = ggml_nelements(dst->src[0]);
10193
10194 ggml_vk_op_f32_opt_step_adamw(
10195 ctx, subctx, dst,
10196 { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f }
10197 );
10198}
10199
10200static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
10201 const size_t n = ggml_nelements(dst->src[0]);
10202
10203 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f, 0.0f, 0.0f });
10204}
10205
10206static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10207 int * op_params = (int *)dst->op_params;
10208
10209 const uint32_t src0_type_size = ggml_type_size(src0->type);
10210 const uint32_t src1_type_size = ggml_type_size(src1->type);
10211 const uint32_t dst_type_size = ggml_type_size(dst->type);
10212
10213 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONCAT, {
10214 (uint32_t)ggml_nelements(dst),
10215 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10216 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10217 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
10218 0,
10219 0.0f, 0.0f, op_params[0],
10220 });
10221}
10222
10223static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10224 const uint32_t src0_type_size = ggml_type_size(src0->type);
10225 const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0);
10226
10227 GGML_TENSOR_UNARY_OP_LOCALS
10228
10229 float sf0 = (float)ne0 / ne00;
10230 float sf1 = (float)ne1 / ne01;
10231 float sf2 = (float)ne2 / ne02;
10232 float sf3 = (float)ne3 / ne03;
10233 float pixel_offset = 0.5f;
10234
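    // With align-corners scaling the endpoints map onto each other, so the scale
    // factors use (n - 1) ratios and no half-pixel offset is applied.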
10235 if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) {
10236 sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0;
10237 sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1;
10238 pixel_offset = 0.0f;
10239 }
10240
10241 ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
10242 (uint32_t)ggml_nelements(dst), 0, 0,
10243 (uint32_t)ne00, (uint32_t)ne01,
10244 (uint32_t)nb00 / src0_type_size, (uint32_t)nb01 / src0_type_size, (uint32_t)nb02 / src0_type_size, (uint32_t)nb03 / src0_type_size,
10245 (uint32_t)ne0, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
10246 sf0, sf1, sf2, sf3, pixel_offset
10247 });
10248}
10249
10250static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10251 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10252 p.param1 = ggml_get_op_params_f32(dst, 0);
10253 p.param2 = ggml_get_op_params_f32(dst, 1);
10254
10255 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p));
10256}
10257
10258static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10259 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst));
10260}
10261
10262static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10263 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst));
10264}
10265
10266static void ggml_vk_add1(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10267 const uint32_t src0_type_size = ggml_type_size(src0->type);
10268 const uint32_t src1_type_size = ggml_type_size(src1->type);
10269 const uint32_t dst_type_size = ggml_type_size(dst->type);
10270
10271 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ADD1, {
10272 (uint32_t)ggml_nelements(src0),
10273 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10274 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10275 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
10276 0,
10277 0.0f, 0.0f, 0,
10278 });
10279}
10280
10281static void ggml_vk_arange(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10282 VK_LOG_DEBUG("ggml_vk_arange(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
10283
10284 vk_op_push_constants pc = {
10285 (uint32_t)ggml_nelements(dst),
10286 1,
10287 ggml_get_op_params_f32(dst, 0),
10288 ggml_get_op_params_f32(dst, 2),
10289 0.0f, 0.0f,
10290 };
10291
10292 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_ARANGE);
10293 GGML_ASSERT(pipeline != nullptr);
10294
10295 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10296 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
10297
10298 std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
10299
10300 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
10301}
10302
10303static void ggml_vk_fill(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10304 VK_LOG_DEBUG("ggml_vk_fill(dst=" << dst << ", ne=" << ggml_nelements(dst) << ")");
10305
10306 vk_op_push_constants pc = {
10307 (uint32_t)ggml_nelements(dst),
10308 1,
10309 ggml_get_op_params_f32(dst, 0),
10310 0.0f,
10311 0.0f, 0.0f,
10312 };
10313
10314 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, dst, GGML_OP_FILL);
10315 GGML_ASSERT(pipeline != nullptr);
10316
10317 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10318 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, false);
10319
10320 std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(dst), 1, 1 };
10321
10322 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { dst_buf }, pc, elements);
10323}
10324
10325static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10326 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst));
10327}
10328
10329static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10330 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
10331}
10332
10333static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10334 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
10335}
10336
10337static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10338 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10339 p.param1 = ggml_get_op_params_f32(dst, 0);
10340
10341 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
10342}
10343
10344static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10345 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10346
10347 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
10348}
10349
10350static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10351 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10352 p.param1 = ggml_get_op_params_f32(dst, 0);
10353 p.param2 = ggml_get_op_params_f32(dst, 1);
10354
10355 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p));
10356}
10357
10358static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10359 vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
10360 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p));
10361}
10362
10363static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10364 const int32_t s0 = ggml_get_op_params_i32(dst, 0);
10365 const int32_t s1 = ggml_get_op_params_i32(dst, 1);
10366 const int32_t s2 = ggml_get_op_params_i32(dst, 2);
10367 const int32_t s3 = ggml_get_op_params_i32(dst, 3);
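    // The four shift values are signed; bias each by 0x8000 so two of them fit into
    // the 16-bit halves of a uint32_t, then bit-copy the packed words into the float
    // push-constant slots (the shader is assumed to undo this packing).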
10368 const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000);
10369 const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000);
10370
10371 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10372 memcpy(&p.param1, &s01_packed, sizeof(float));
10373 memcpy(&p.param2, &s23_packed, sizeof(float));
10374
10375 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p));
10376}
10377
10378static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10379 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10380 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p));
10381}
10382
10383static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10384 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
10385 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p));
10386}
10387
10388static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10389 uint32_t ne = (uint32_t)ggml_nelements(src0);
10390 if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
10391 // Convert from number of logical elements to 2- or 4-byte units.
10392 ne /= ggml_blck_size(src0->type);
10393 if ((ggml_type_size(src0->type) % 4) == 0) {
10394 ne *= ggml_type_size(src0->type) / 4;
10395 } else {
10396 ne *= ggml_type_size(src0->type) / 2;
10397 }
10398 }
10399
10400 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne);
10401 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p));
10402}
10403
10404static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10405 const uint32_t src0_type_size = ggml_type_size(src0->type);
10406 const uint32_t src1_type_size = ggml_type_size(src1->type);
10407 const uint32_t dst_type_size = ggml_type_size(dst->type);
10408
10409     // Skip empty set_rows operations. For most ops the empty check at the start
10410 // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst
10411 // with empty srcs.
10412 if (ggml_is_empty(src0) || ggml_is_empty(src1)) {
10413 return;
10414 }
10415
10416 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SET_ROWS, {
10417 (uint32_t)ggml_nelements(src0),
10418 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10419 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10420 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
10421 0,
10422 0.0f, 0.0f, 0,
10423 });
10424}
10425
10426static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10427 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10428}
10429
10430static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10431 float * op_params = (float *)dst->op_params;
10432
10433 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10434}
10435
10436static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10437 const int * int_op_params = (const int *)dst->op_params;
10438 const float * float_op_params = (const float *)dst->op_params;
10439
10440 const uint32_t num_groups = int_op_params[0];
10441 const float eps = float_op_params[1];
10442 const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
10443
10444 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f, 0.0f, 0.0f });
10445}
10446
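// Number of RMS partial sums for one row: ne[0] divided (rounded up) by the add_rms
// pipeline's workgroup denominator, i.e. presumably one partial per workgroup.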
10447static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
10448 const uint32_t ne = (uint32_t)node->ne[0];
10449 const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0];
10450 const uint32_t num_partials = CEIL_DIV(ne, denom);
10451 return num_partials;
10452}
10453
10454static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) {
10455 const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node);
10456 const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment);
10457 return num_bytes;
10458}
10459
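// Build the ROPE push constants from dst->op_params, converting strides from bytes to
// elements; used below by the fused rms_norm+mul+rope path.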
10460static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *dst, const ggml_tensor *src0, const bool has_ff, bool backprop, const uint32_t set_rows_stride) {
10461 const int n_dims = ((const int32_t *) dst->op_params)[1];
10462 const int mode = ((const int32_t *) dst->op_params)[2];
10463 // const int n_ctx = ((const int32_t *) dst->op_params)[3];
10464 const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
10465 const float freq_base = ((const float *) dst->op_params)[5];
10466 const float freq_scale = ((const float *) dst->op_params)[6];
10467 const float ext_factor = ((const float *) dst->op_params)[7];
10468 const float attn_factor = ((const float *) dst->op_params)[8];
10469 const float beta_fast = ((const float *) dst->op_params)[9];
10470 const float beta_slow = ((const float *) dst->op_params)[10];
10471 int sections[4] {};
10472 if (mode & GGML_ROPE_TYPE_MROPE) {
10473 memcpy(sections, (const int32_t *) dst->op_params + 11, sizeof(int)*4);
10474 }
10475
10476 const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
10477
10478 float corr_dims[2];
10479 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
10480
10481 const float theta_scale = powf(freq_base, -2.0f/n_dims);
10482
10483 uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
10484 uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
10485 uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
10486
10487 uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
10488 uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
10489 uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
10490
10491 vk_op_rope_push_constants rope {
10492 (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
10493 freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
10494 { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
10495
10496 (uint32_t)src0->ne[0],
10497 (uint32_t)src0->ne[1],
10498 (uint32_t)src0->ne[2],
10499 nb01, nb02, nb03,
10500 nb11, nb12, nb13,
10501 };
10502
10503 return rope;
10504}
10505
10506static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx, float * op_params) {
10507 ggml_tensor * dst;
10508 const ggml_tensor * src0;
10509 const ggml_tensor * src1;
10510
10511 if (ctx->num_additional_fused_ops > 0) {
10512 // fused rms_norm + mul
10513 ggml_tensor *mul = cgraph->nodes[node_idx + 1];
10514 ggml_tensor *other_src = mul->src[0] == cgraph->nodes[node_idx + 0] ? mul->src[1] : mul->src[0];
10515 dst = mul;
10516 src0 = cgraph->nodes[node_idx]->src[0];
10517 src1 = other_src;
10518 } else {
10519 dst = cgraph->nodes[node_idx];
10520 src0 = src1 = dst->src[0];
10521 }
10522
10523 const uint32_t src0_type_size = ggml_type_size(src0->type);
10524 const uint32_t src1_type_size = ggml_type_size(src1->type);
10525 const uint32_t dst_type_size = ggml_type_size(dst->type);
10526
10527 uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0;
10528
10529 vk_op_binary_push_constants bin {
10530 (uint32_t)ggml_nelements(src0),
10531 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
10532 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
10533 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
10534 0,
10535 op_params[0], 0.0f, (int32_t)param3,
10536 };
10537
10538 // more than one fused op means rms_norm+mul+rope
10539 if (ctx->num_additional_fused_ops > 1) {
10540 static constexpr uint32_t max_tensors = 7;
10541 const ggml_tensor *tensors[max_tensors] {};
10542
10543 ggml_tensor *rms = cgraph->nodes[node_idx + 0];
10544 ggml_tensor *mul = cgraph->nodes[node_idx + 1];
10545 ggml_tensor *rope = cgraph->nodes[node_idx + 2];
10546
10547 ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
10548
10549 bool do_set_rows = ctx->num_additional_fused_ops == 4;
10550
10551 tensors[0] = rms->src[0];
10552 tensors[1] = other_src;
10553 tensors[2] = mul;
10554 tensors[3] = rope->src[1]; // pos
10555 tensors[4] = rope->src[2]; // ff
10556 tensors[5] = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; // dst
10557 tensors[6] = do_set_rows ? tensors[5]->src[1] : nullptr;
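        // Row stride (in elements) of the final destination when the fused chain ends
        // in SET_ROWS; 0 otherwise.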
10558 const uint32_t set_rows_stride = do_set_rows ? tensors[5]->nb[1] / ggml_type_size(tensors[5]->type) : 0;
10559
10560 vk_op_rms_norm_mul_rope_push_constants pc;
10561 pc.bin = bin;
10562 pc.rope = ggml_vk_make_rope_constants(rope, rope->src[0], tensors[4] != nullptr, false, set_rows_stride);
10563
10564 vk_pipeline pipeline = tensors[5]->type == GGML_TYPE_F16 ? ctx->device->pipeline_rms_norm_mul_rope_f32_f16 : ctx->device->pipeline_rms_norm_mul_rope_f32_f32;
10565
10566 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10567
10568 ggml_backend_vk_buffer_context * buf_ctx[max_tensors];
10569 vk_buffer buf[max_tensors];
10570 size_t offset[max_tensors];
10571 bool uma[max_tensors];
10572
10573 for (uint32_t i = 0; i < max_tensors; ++i) {
10574 if (!tensors[i]) {
10575 // If any remaining descriptors are unused, just point them at src[0]
10576 buf[i] = buf[0];
10577 offset[i] = 0;
10578 continue;
10579 }
10580 buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context;
10581 buf[i] = nullptr;
10582 offset[i] = 0;
10583 uma[i] = false;
10584
10585 if (ctx->device->uma) {
10586 ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]);
10587 uma[i] = buf[i] != nullptr;
10588 }
10589 if (!uma[i]) {
10590 buf[i] = buf_ctx[i]->dev_buffer;
10591 offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs;
10592 }
10593 GGML_ASSERT(buf[i] != nullptr);
10594 }
10595
10596 std::array<uint32_t, 3> elements;
10597 elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
10598
10599 static_assert(max_tensors == 7);
10600 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10601 {
10602 ggml_vk_subbuffer(ctx, buf[0], offset[0]),
10603 ggml_vk_subbuffer(ctx, buf[1], offset[1]),
10604 ggml_vk_subbuffer(ctx, buf[2], offset[2]),
10605 ggml_vk_subbuffer(ctx, buf[3], offset[3]),
10606 ggml_vk_subbuffer(ctx, buf[4], offset[4]),
10607 ggml_vk_subbuffer(ctx, buf[5], offset[5]),
10608 ggml_vk_subbuffer(ctx, buf[6], offset[6]),
10609 }, pc, elements);
10610 } else {
10611 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM, std::move(bin));
10612 }
10613
10614 if (ctx->do_add_rms_partials_offset_calculation) {
10615 ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0);
10616 ctx->do_add_rms_partials = false;
10617 ctx->do_add_rms_partials_offset_calculation = false;
10618 }
10619}
10620
10621static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10622 float * op_params = (float *)dst->op_params;
10623 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10624}
10625
10626static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10627 float * op_params = (float *)dst->op_params;
10628 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10629}
10630
10631static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10632 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
10633}
10634
10635static void ggml_vk_xielu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10636 float * op_params = (float *)dst->op_params;
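    // op_params[0] holds the ggml_unary_op id; params 1..4 are the four XIELU coefficients and
    // are forwarded to the shader unchanged.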
10637 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_UNARY,
10638 {
10639 (uint32_t)ggml_nelements(src0), 0,
10640 op_params[1], op_params[2], op_params[3], op_params[4]
10641 }
10642 );
10643}
10644
10645static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10646 const float * op_params_f = (const float *)dst->op_params;
10647
10648 const bool swapped = (bool)dst->op_params[1];
10649 const bool split = src1 != nullptr;
10650 const float alpha = op_params_f[2];
10651 const float limit = op_params_f[3];
10652
10653 GGML_ASSERT(ggml_is_contiguous(src0));
10654
10655 if (!split) {
10656 GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
10657 } else {
10658 GGML_ASSERT(src0->ne[0] == src1->ne[0]);
10659 GGML_ASSERT(src0->ne[0] == dst->ne[0]);
10660 GGML_ASSERT(src0->type == src1->type);
10661 }
10662
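    // mode tells the shader how the two GLU inputs are laid out:
    //   0 = both halves packed in src0, 1 = packed in src0 with the halves swapped,
    //   2 = split across src0 and src1.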
10663 const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
10664
10665 ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GLU,
10666 {
10667 (uint32_t)ggml_nelements(dst),
10668 (uint32_t)src0->ne[0],
10669 (uint32_t)dst->ne[0],
10670 mode,
10671 alpha,
10672 limit
10673 });
10674}
10675
10676static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10677 int32_t * op_params = (int32_t *)dst->op_params;
10678 ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
10679}
10680
10681static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
10682 float * op_params = (float *)dst->op_params;
10683
10684 float scale = op_params[0];
10685 float max_bias = op_params[1];
10686
10687 const uint32_t ncols = (uint32_t)src0->ne[0];
10688 const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
10689 const uint32_t nrows_y = (uint32_t)src0->ne[1];
10690
10691 const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
10692 const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
10693 const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
10694 const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
10695 const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
10696
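    // ALiBi slope bases: the per-head slopes are derived in the shader from m0, m1 and
    // n_head_log2 (heads below n_head_log2 use powers of m0, the rest odd powers of m1).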
10697 const uint32_t n_head_kv = src0->ne[2];
10698 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
10699
10700 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
10701 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
10702
10703 vk_op_soft_max_push_constants pc {
10704 ncols,
10705 src1 != nullptr ? nrows_y : (uint32_t)0,
10706 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
10707 ne12, ne13,
10708 nb11, nb12, nb13,
10709 scale, max_bias,
10710 m0, m1,
10711 n_head_log2,
10712 nrows_x,
10713 src2 != nullptr
10714 };
10715
10716 if (ncols <= 16384) {
10717 ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
10718 } else {
10719
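        // Rows too wide for the single-pass shader are handled in three dispatches: each row is
        // split into workgroups of elems_per_wg columns, per-workgroup partials (one float per
        // workgroup per row) are staged in the two scratch buffers, and the dispatches are
        // separated by barriers. The exact split of work between the three passes is defined by
        // the soft_max "large" shader variants.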
10720 vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
10721 vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
10722 vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
10723 vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
10724
10725 uint32_t elems_per_wg = 128 * 4;
10726 uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
10727 size_t tmp_size = num_wgs * nrows_x * sizeof(float);
10728
10729 if (ctx->prealloc_size_x < tmp_size) {
10730 ctx->prealloc_size_x = tmp_size;
10731 ggml_vk_preallocate_buffers(ctx, subctx);
10732 }
10733 if (ctx->prealloc_size_y < tmp_size) {
10734 ctx->prealloc_size_y = tmp_size;
10735 ggml_vk_preallocate_buffers(ctx, subctx);
10736 }
10737 if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
10738 ggml_vk_sync_buffers(ctx, subctx);
10739 }
10740
10741 vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
10742 vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
10743
10744 std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
10745
10746 vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
10747 vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
10748 vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
10749
10750 ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
10751 ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
10752 ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
10753
10754 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10755 ggml_vk_sync_buffers(ctx, subctx);
10756 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10757 ggml_vk_sync_buffers(ctx, subctx);
10758 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10759
10760 ctx->prealloc_x_need_sync = true;
10761 ctx->prealloc_y_need_sync = true;
10762 }
10763}
10764
10765static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
10766 float * op_params = (float *)dst->op_params;
10767 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1], 0.0f, 0.0f });
10768}
10769
10770static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
10771 topk_moe_mode mode = ctx->fused_topk_moe_mode;
10772 ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
10773 ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
10774 ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
10775 ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
10776 (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
10777 cgraph->nodes[node_idx + 3];
10778
10779 GGML_ASSERT(logits->type == GGML_TYPE_F32);
10780 GGML_ASSERT(bias->type == GGML_TYPE_F32);
10781 GGML_ASSERT(weights->type == GGML_TYPE_F32);
10782 GGML_ASSERT(ids->type == GGML_TYPE_I32);
10783
10784 const int n_experts = logits->ne[0];
10785 const int n_rows = logits->ne[1];
10786 const int n_expert_used = weights->ne[1];
10787
10788 GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
10789
10790 vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
10791
10792 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10793
10794 vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
10795 vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
10796 vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
10797 vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
10798
10799 vk_op_topk_moe_push_constants pc {};
10800 pc.n_rows = n_rows;
10801 pc.n_experts_push = n_experts;
10802 pc.n_expert_used = n_expert_used;
10803 pc.clamp_min = -std::numeric_limits<float>::infinity();
10804 pc.clamp_max = std::numeric_limits<float>::infinity();
10805 if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
10806 ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
10807 GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10808 pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10809 pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10810 }
10811 if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
10812 ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
10813 GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
10814 pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
10815 pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
10816 }
10817
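// Gating function selector passed to the shader in pc.gating_func; these values are assumed to
// match the encoding expected by the topk_moe shader.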
10818#define GATING_FUNC_SOFTMAX 0
10819#define GATING_FUNC_SIGMOID 1
10820#define GATING_FUNC_SOFTMAX_WEIGHT 2
10821
10822 pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
10823 mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
10824 GATING_FUNC_SOFTMAX;
10825 pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10826 pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
10827 if (ctx->fused_topk_moe_scale) {
10828 GGML_ASSERT(weights->op == GGML_OP_SCALE);
10829 pc.output_scale = ggml_get_op_params_f32(weights, 0);
10830 pc.output_bias = ggml_get_op_params_f32(weights, 1);
10831 } else {
10832 pc.output_scale = 1.0f;
10833 pc.output_bias = 0.0f;
10834 }
10835
10836 GGML_ASSERT(n_expert_used <= n_experts);
10837
10838 const uint32_t rows_per_block = 4;
10839 std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
10840
10841 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
10842}
10843
10844static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
10845 ggml_tensor * dst = cgraph->nodes[node_idx];
10846 const ggml_tensor * src0 = dst->src[0];
10847 const ggml_tensor * src1 = dst->src[1];
10848 const ggml_tensor * src2 = dst->src[2];
10849 const ggml_tensor * src3 = nullptr;
10850 const int n_dims = ((int32_t *) dst->op_params)[1];
10851 const int mode = ((int32_t *) dst->op_params)[2];
10852 // const int n_ctx = ((int32_t *) dst->op_params)[3];
10853 const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
10854 const float freq_base = ((float *) dst->op_params)[5];
10855 const float beta_fast = ((float *) dst->op_params)[9];
10856 const float beta_slow = ((float *) dst->op_params)[10];
10857 int sections[4] {};
10858 if (mode & GGML_ROPE_TYPE_MROPE) {
10859 memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4);
10860 }
10861
10862 float corr_dims[2];
10863 ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
10864
10865 uint32_t set_rows_stride = 0;
10866    // A fused rope + view + set_rows chain passes the set_rows destination stride in set_rows_stride,
10867    // overrides dst with the set_rows output tensor, and supplies the row indices through src3.
10868 if (ctx->num_additional_fused_ops > 0) {
10869 set_rows_stride = cgraph->nodes[node_idx + 2]->nb[1] / ggml_type_size(cgraph->nodes[node_idx + 2]->type);
10870 src3 = cgraph->nodes[node_idx + 2]->src[1];
10871 dst = cgraph->nodes[node_idx + 2];
10872 }
10873
10874 ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, src2, src3, dst, GGML_OP_ROPE,
10875 ggml_vk_make_rope_constants(cgraph->nodes[node_idx], src0, src2 != nullptr, backprop, set_rows_stride));
10876}
10877
10878static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10879 const uint32_t * op_params = (const uint32_t *)dst->op_params;
10880
10881 uint32_t ncols = src0->ne[0];
10882 uint32_t nrows = ggml_nrows(src0);
10883
10884 uint32_t ncols_pad_log2 = (uint32_t)ceilf(log2f(float(ncols)));
10885 uint32_t ncolsp2 = 1 << ncols_pad_log2;
10886
10887 vk_op_argsort_push_constants pc { ncols, ncolsp2, ncols_pad_log2, nrows, op_params[0], 0, 0, 0, 0, };
10888
10889 // Pick the largest workgroup size <= ncolsp2
10890 uint32_t pipeline_idx = std::min(ncols_pad_log2, num_argsort_pipelines - 1);
10891
10892 // Use the "small" argsort shader if the whole sort can be done by a single workgroup.
10893 bool use_small = ncols_pad_log2 <= ctx->device->max_workgroup_size_log2 &&
10894 ctx->device->pipeline_argsort_f32[pipeline_idx] != nullptr;
10895
10896 vk_pipeline pipeline = use_small ? ctx->device->pipeline_argsort_f32[pipeline_idx]
10897 : ctx->device->pipeline_argsort_large_f32[pipeline_idx];
10898
10899 vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0);
10900 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10901 vk_subbuffer subbuf1 = dst_buf;
10902
10903    // Reserve space for an ivec2 per element, with each row padded to a power-of-two column count
10904 if (!use_small) {
10905 const size_t x_sz = size_t{ncolsp2} * nrows * 2 * sizeof(int);
10906
10907 if (ctx->prealloc_size_x < x_sz) {
10908 ctx->prealloc_size_x = x_sz;
10909 ggml_vk_preallocate_buffers(ctx, subctx);
10910 }
10911 if (ctx->prealloc_x_need_sync) {
10912 ggml_vk_sync_buffers(ctx, subctx);
10913 }
10914 subbuf1 = { ctx->prealloc_x, 0, ctx->prealloc_x->size };
10915 }
10916
10917 std::array<uint32_t, 3> elements;
10918
10919 elements[0] = ncolsp2;
10920 elements[1] = std::min((uint32_t)ggml_nrows(src0), ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
10921 elements[2] = 1;
10922
10923 // First dispatch initializes tmp_idx and does the first N passes where
10924 // there is only communication between threads in the same workgroup.
10925 {
10926 vk_op_argsort_push_constants pc2 = pc;
10927 pc2.outer_start = 0;
10928 pc2.outer_end = std::min(ncols_pad_log2, ctx->device->max_workgroup_size_log2);
10929 pc2.inner_start = 0;
10930 pc2.inner_end = 100;
10931 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10932 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
10933 }
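    // For sorts wider than one workgroup the remaining bitonic merge stages need global
    // synchronization: outer/inner follow the usual bitonic indexing (outer = merge stage,
    // inner = pass within the stage), with a barrier after every dispatch. Once the compare
    // distance fits inside a workgroup, the rest of the stage is folded into a single dispatch.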
10934 if (!use_small) {
10935 ggml_vk_sync_buffers(ctx, subctx);
10936 // Loop over outer/inner passes, synchronizing between each pass.
10937 for (uint32_t outer = ctx->device->max_workgroup_size_log2; outer < ncols_pad_log2; ++outer) {
10938 for (uint32_t inner = 0; inner < outer + 1; ++inner) {
10939 vk_op_argsort_push_constants pc2 = pc;
10940 pc2.outer_start = outer;
10941 pc2.outer_end = outer + 1;
10942 pc2.inner_start = inner;
10943 pc2.inner_end = inner + 1;
10944 // When the inner idx is large enough, there's only communication
10945 // within a workgroup. So the remaining inner iterations can all
10946 // run in the same dispatch.
10947 if (outer - inner < pipeline_idx) {
10948 pc2.inner_end = 100;
10949 inner = outer;
10950 pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx];
10951 } else {
10952 // Smaller workgroup empirically seems to perform better
10953 pipeline = ctx->device->pipeline_argsort_large_f32[pipeline_idx - 2];
10954 }
10955 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10956 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc2, elements);
10957 ggml_vk_sync_buffers(ctx, subctx);
10958 }
10959 }
10960 ctx->prealloc_x_need_sync = true;
10961 }
10962}
10963
10964static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10965 uint32_t ncols = src0->ne[0];
10966 uint32_t nrows = ggml_nrows(src0);
10967 uint32_t k = dst->ne[0];
10968
10969 vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
10970
10971 if (ctx->prealloc_x_need_sync) {
10972 ggml_vk_sync_buffers(ctx, subctx);
10973 }
10974
10975 std::array<uint32_t, 3> elements;
10976 elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
10977 elements[2] = 1;
10978
10979 uint32_t num_elements = ncols;
10980
10981 // Each iteration reduces a workgroup's worth of elements down to the K
10982 // largest elements. Repeat until we have the top K elements.
10983 // Need to do at least one iteration to write out the results.
10984 bool done_one_iter = false;
10985 uint32_t dbl_buf_index = 0;
10986    size_t dbl_buf_size = 0; // set on the first iteration, before it is read
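    // Candidates ping-pong between the two halves of the prealloc_x scratch buffer: each pass
    // writes at most k survivors per workgroup per row (as ivec2) into the half selected by
    // dbl_buf_index, and the next pass reads them back from that half.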
10987 while (num_elements > k || !done_one_iter) {
10988
10989        // Prefer a pipeline index as small as num_topk_pipelines - 3 for performance reasons,
10990        // but a larger K requires a larger workgroup.
10991 uint32_t max_pipeline = num_topk_pipelines - 1;
10992 uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
10993 max_pipeline = std::min(preferred_pipeline, max_pipeline);
10994 uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
10995 // require full subgroup
10996 min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
10997
10998 uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
10999 pipeline_idx = std::min(pipeline_idx, max_pipeline);
11000 pipeline_idx = std::max(pipeline_idx, min_pipeline);
11001
11002 if (num_elements > (1u << pipeline_idx)) {
11003 // If we could finish on this loop iteration (i.e. a single workgroup)
11004 // then do so. It's better than the overhead of another pass.
11005 for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
11006 if (num_elements <= (1u << i)) {
11007 pipeline_idx = i;
11008 break;
11009 }
11010 }
11011 }
11012
11013 vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
11014        // If the device doesn't support a pipeline this large, fall back to a smaller one
11015 while (!pipeline) {
11016 pipeline_idx--;
11017 GGML_ASSERT(pipeline_idx >= min_pipeline);
11018 pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
11019 }
11020
11021 vk_op_topk_push_constants pc2 = pc;
11022 pc2.ncols_input = num_elements;
11023
11024 // Number of elements remaining after this pass
11025 uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
11026
11027 pc2.ncols_output = num_dst_elements;
11028
11029 if (!done_one_iter) {
11030 // Reserve space for ivec2 per element, double buffered
11031 // K per workgroup per row
11032 dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
11033 dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
11034 const size_t x_sz = dbl_buf_size * 2;
11035
11036 if (ctx->prealloc_size_x < x_sz) {
11037 ctx->prealloc_size_x = x_sz;
11038 ggml_vk_preallocate_buffers(ctx, subctx);
11039 }
11040 }
11041
11042 vk_subbuffer src_buf;
11043 vk_subbuffer dst_buf;
11044
11045 if (num_elements == ncols) {
11046 pc2.first_pass = 1;
11047 src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
11048 } else {
11049 src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
11050 }
11051 if (num_dst_elements == k) {
11052 pc2.last_pass = 1;
11053 dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
11054 } else {
11055 dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
11056 }
11057
11058 elements[0] = num_elements;
11059
11060 ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
11061 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
11062 num_elements = num_dst_elements;
11063 dbl_buf_index ^= 1;
11064 if (num_elements > k) {
11065 ggml_vk_sync_buffers(ctx, subctx);
11066 }
11067 done_one_iter = true;
11068 }
11069 ctx->prealloc_x_need_sync = true;
11070}
11071
11072static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11073 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
11074 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
11075}
11076
11077static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11078 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11079 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p);
11080}
11081
11082static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11083 vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11084 p.weight = 1.0f / (float)src0->ne[0];
11085 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
11086}
11087
11088static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11089 vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
11090 // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
11091 // For fewer, larger rows, use the multipass shader to spread each row across SMs.
11092 if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
11093 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
11094 return;
11095 }
11096
11097 // First pass computes partial sums within a block, and stores the last partial
11098 // to the temp buffer. Second pass sums the block partials from the temp buffer
11099 // and adds that to the result of the first pass.
11100 vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
11101 vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
11102 GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);
11103
11104 ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
11105 ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
11106
11107 std::array<uint32_t, 3> elements;
11108
11109 elements[0] = dst->ne[0];
11110 elements[1] = (uint32_t)ggml_nrows(dst);
11111 elements[2] = 1;
11112
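    // The block partials are staged in the split_k scratch buffer, so its sync flag is used to
    // order this op against other users of that buffer.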
11113 size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
11114
11115 if (ctx->prealloc_size_split_k < temp_size) {
11116 ctx->prealloc_size_split_k = temp_size;
11117 ggml_vk_preallocate_buffers(ctx, subctx);
11118 }
11119
11120 vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
11121 vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
11122 vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
11123
11124 if (ctx->prealloc_split_k_need_sync) {
11125 ggml_vk_sync_buffers(ctx, subctx);
11126 }
11127
11128 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
11129 ggml_vk_sync_buffers(ctx, subctx);
11130 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
11131
11132 ctx->prealloc_split_k_need_sync = true;
11133}
11134
11135static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11136 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f, 0.0f, 0.0f });
11137}
11138
11139static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11140 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f, 0.0f, 0.0f });
11141}
11142
11143static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11144 const uint32_t src0_type_size = ggml_type_size(src0->type);
11145 const uint32_t src1_type_size = ggml_type_size(src1->type);
11146 const uint32_t dst_type_size = ggml_type_size(dst->type);
11147
11148 ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
11149 (uint32_t)ggml_nelements(src0),
11150 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
11151 (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
11152 (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
11153 0,
11154 0.0f, 0.0f, 0,
11155 });
11156}
11157
11158static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11159 const int32_t s0 = dst->op_params[0];
11160 const int32_t s1 = dst->op_params[1];
11161 const int32_t p0 = dst->op_params[2];
11162 const int32_t p1 = dst->op_params[3];
11163 const int32_t d0 = dst->op_params[4];
11164 const int32_t d1 = dst->op_params[5];
11165
11166 const bool is_2D = dst->op_params[6] == 1;
11167
11168 const uint32_t IC = src1->ne[is_2D ? 2 : 1];
11169 const uint32_t IH = is_2D ? src1->ne[1] : 1;
11170 const uint32_t IW = src1->ne[0];
11171
11172 const uint32_t KH = is_2D ? src0->ne[1] : 1;
11173 const uint32_t KW = src0->ne[0];
11174
11175 const uint32_t OH = is_2D ? dst->ne[2] : 1;
11176 const uint32_t OW = dst->ne[1];
11177
11178 const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
11179 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
11180
11181 const uint32_t pelements = OW * KW * KH;
11182 const uint32_t batch = src1->ne[is_2D ? 3 : 2];
11183
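    // The im2col shader addresses dst through a buffer device address, so compute the raw GPU
    // address of the output tensor here and pass it in the push constants.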
11184 const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
11185 const vk_buffer d_buf = d_buf_ctx->dev_buffer;
11186
11187 const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
11188
11189 ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL, {
11190 dst_addr,
11191 batch_offset, offset_delta,
11192 IC, IW, IH, OW, OH, KW, KH,
11193 pelements,
11194 IC * KH * KW,
11195 s0, s1, p0, p1, d0, d1, batch * IC
11196 });
11197}
11198
11199static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11200 GGML_TENSOR_BINARY_OP_LOCALS
11201
11202 const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
11203 const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
11204 const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
11205 const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
11206 const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
11207 const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
11208 const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
11209 const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
11210 const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
11211 const int32_t IC = ((const int32_t *)(dst->op_params))[9];
11212
11213 const int64_t N = ne13 / IC;
11214 const int64_t ID = ne12;
11215 const int64_t IH = ne11;
11216 const int64_t IW = ne10;
11217
11218 const int64_t KD = ne02;
11219 const int64_t KH = ne01;
11220 const int64_t KW = ne00;
11221
11222 const int64_t OD = ne3 / N;
11223 const int64_t OH = ne2;
11224 const int64_t OW = ne1;
11225
11226 const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
11227 const vk_buffer d_buf = d_buf_ctx->dev_buffer;
11228
11229 const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
11230
11231 vk_op_im2col_3d_push_constants pc {};
11232
11233 pc.dst_addr = dst_addr;
11234 pc.nb10 = nb10 / ggml_type_size(src1->type);
11235 pc.nb11 = nb11 / ggml_type_size(src1->type);
11236 pc.nb12 = nb12 / ggml_type_size(src1->type);
11237 pc.nb13 = nb13 / ggml_type_size(src1->type);
11238 pc.s0 = s0;
11239 pc.s1 = s1;
11240 pc.s2 = s2;
11241 pc.p0 = p0;
11242 pc.p1 = p1;
11243 pc.p2 = p2;
11244 pc.d0 = d0;
11245 pc.d1 = d1;
11246 pc.d2 = d2;
11247 pc.IW = IW;
11248 pc.IH = IH;
11249 pc.ID = ID;
11250 pc.IC = IC;
11251 pc.KW = KW;
11252 pc.OH = OH;
11253 pc.KD_KH_KW = KD*KH*KW;
11254 pc.KH_KW = KH*KW;
11255 pc.IC_KD_KH_KW = IC*KD*KH*KW;
11256 pc.N_OD_OH = N*OD*OH;
11257 pc.OD_OH = OD*OH;
11258 pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
11259 pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
11260 pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
11261
11262 ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc));
11263}
11264
11265static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11266 const uint32_t dim = dst->op_params[0];
11267 const uint32_t max_period = dst->op_params[1];
11268 const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type);
11269
11270 ggml_vk_op_f32<vk_op_timestep_embedding_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, {
11271 nb1, dim, max_period,
11272 });
11273}
11274
11275static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11276 // src0: (K, Cout, Cin, 1) -- kernel
11277 // src1: (L, Cin, 1, 1) -- input
11278 // dst: (*, Cout, 1, 1)
11279
11280 GGML_ASSERT(src0->type == GGML_TYPE_F32);
11281 GGML_ASSERT(src1->type == GGML_TYPE_F32);
11282 GGML_ASSERT( dst->type == GGML_TYPE_F32);
11283
11284 GGML_TENSOR_BINARY_OP_LOCALS
11285
11286 GGML_ASSERT(nb00 == sizeof(float));
11287 GGML_ASSERT(nb10 == sizeof(float));
11288
11289 const int32_t s0 = dst->op_params[0];
11290
11291 vk_op_conv_transpose_1d_push_constants p{};
11292 p.Cout = static_cast<uint32_t>(ne01);
11293 p.Cin = static_cast<uint32_t>(ne02);
11294 p.K = static_cast<uint32_t>(ne00);
11295 p.L = static_cast<uint32_t>(ne10);
11296 p.KL = static_cast<uint32_t>(ne0);
11297 p.nb01 = static_cast<uint32_t>(nb01 / nb00);
11298 p.nb02 = static_cast<uint32_t>(nb02 / nb00);
11299 p.nb11 = static_cast<uint32_t>(nb11 / nb10);
11300 p.nb1 = static_cast<uint32_t>(nb1 / nb0);
11301 p.s0 = static_cast<uint32_t>(s0);
11302
11303 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
11304}
11305
11306static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11307 uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
11308 const int32_t k1 = dst->op_params[1];
11309 const int32_t k0 = dst->op_params[2];
11310 const int32_t s1 = dst->op_params[3];
11311 const int32_t s0 = dst->op_params[4];
11312 const int32_t p1 = dst->op_params[5];
11313 const int32_t p0 = dst->op_params[6];
11314
11315 const uint32_t IH = src0->ne[1];
11316 const uint32_t IW = src0->ne[0];
11317
11318 const uint32_t N = dst->ne[3];
11319
11320 const uint32_t OC = dst->ne[2];
11321 const uint32_t OH = dst->ne[1];
11322 const uint32_t OW = dst->ne[0];
11323
11324 const uint32_t parallel_elements = N * OC * OH * OW;
11325
11326 ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
11327 IW, IH, OW, OH, OC,
11328 parallel_elements,
11329 op,
11330 k0, k1, s0, s1, p0, p1,
11331 });
11332}
11333
11334static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0,
11335 const ggml_tensor * src1, ggml_tensor * dst) {
11336 GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
11337 GGML_ASSERT(src1->type == GGML_TYPE_F32);
11338 GGML_ASSERT(dst->type == GGML_TYPE_F32);
11339
11340 GGML_TENSOR_BINARY_OP_LOCALS
11341 GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t));
11342 GGML_ASSERT(nb10 == sizeof(float));
11343 GGML_ASSERT(nb0 == sizeof(float));
11344
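    // This path serves both GGML_OP_CONV_2D and GGML_OP_CONV_TRANSPOSE_2D; for the transposed
    // variant the kernel's Cin/Cout dimensions are swapped below.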
11345 bool transpose = dst->op == GGML_OP_CONV_TRANSPOSE_2D;
11346
11347 vk_op_conv2d_push_constants p{};
11348 p.Cout = static_cast<uint32_t>(!transpose ? ne03 : ne02);
11349 p.Cin = static_cast<uint32_t>(!transpose ? ne02 : ne03);
11350 p.N = static_cast<uint32_t>(ne13);
11351 GGML_ASSERT(p.Cout == ne2);
11352 GGML_ASSERT(p.Cin == ne12);
11353
11354 p.W = static_cast<uint32_t>(ne10);
11355 p.H = static_cast<uint32_t>(ne11);
11356 p.OW = static_cast<uint32_t>(ne0);
11357 p.OH = static_cast<uint32_t>(ne1);
11358
11359 p.nb01 = static_cast<uint32_t>(nb01 / nb00);
11360 p.nb02 = static_cast<uint32_t>(nb02 / nb00);
11361 p.nb03 = static_cast<uint32_t>(nb03 / nb00);
11362
11363 p.nb11 = static_cast<uint32_t>(nb11 / nb10);
11364 p.nb12 = static_cast<uint32_t>(nb12 / nb10);
11365 p.nb13 = static_cast<uint32_t>(nb13 / nb10);
11366
11367 p.nb1 = static_cast<uint32_t>(nb1 / nb0);
11368 p.nb2 = static_cast<uint32_t>(nb2 / nb0);
11369 p.nb3 = static_cast<uint32_t>(nb3 / nb0);
11370
11371 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p));
11372}
11373
11374static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
11375 vk_op_conv2d_dw_push_constants p{};
11376 p.ne = ggml_nelements(dst);
11377 p.channels = dst->ne[2];
11378 p.batches = dst->ne[3];
11379 p.dst_w = dst->ne[0];
11380 p.dst_h = dst->ne[1];
11381 p.src_w = src1->ne[0];
11382 p.src_h = src1->ne[1];
11383 p.knl_w = src0->ne[0];
11384 p.knl_h = src0->ne[1];
11385 p.stride_x = dst->op_params[0];
11386 p.stride_y = dst->op_params[1];
11387 p.pad_x = dst->op_params[2];
11388 p.pad_y = dst->op_params[3];
11389 p.dilation_x = dst->op_params[4];
11390 p.dilation_y = dst->op_params[5];
11391
11392 GGML_ASSERT(src0->ne[3] == p.channels);
11393 GGML_ASSERT(src1->ne[3] == p.batches);
11394
11395 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p));
11396}
11397
11398static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11399 const float * op_params = (const float *)dst->op_params;
11400 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f });
11401}
11402
11403#ifdef GGML_VULKAN_RUN_TESTS
11404static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) {
11405 if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) {
11406 return;
11407 }
11408 i0 = std::max(i0, 5);
11409 i1 = std::max(i1, 5);
11410 i2 = std::max(i2, 0);
11411 fprintf(stderr, " ");
11412 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
11413 fprintf(stderr, "%7d ", idx1);
11414 }
11415 fprintf(stderr, "\n");
11416 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
11417 fprintf(stderr, "%7d: ", idx0);
11418 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
11419 if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) {
11420 float val;
11421 if (type == GGML_TYPE_F32) {
11422 val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0);
11423 } else if (type == GGML_TYPE_F16) {
11424 val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0));
11425 } else {
11426 GGML_ABORT("fatal error");
11427 }
11428 fprintf(stderr, "% 7.2f ", val);
11429 } else {
11430 fprintf(stderr, " ");
11431 }
11432 }
11433 fprintf(stderr, "\n");
11434 }
11435}
11436
11437template <typename X_TYPE, typename Y_TYPE>
11438static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) {
11439 VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")");
11440 const size_t x_ne = m * k * batch;
11441 const size_t y_ne = k * n * batch;
11442 const size_t d_ne = m * n * batch;
11443
11444 vk_pipeline p;
11445 std::string shname;
11446 if (shader_size == 0) {
11447 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11448 p = ctx->device->pipeline_matmul_f32->a_s;
11449 shname = "F32_ALIGNED_S";
11450 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11451 p = ctx->device->pipeline_matmul_f32_f16->a_s;
11452 shname = "F32_F16_ALIGNED_S";
11453 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11454 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
11455 shname = "F16_F32_ALIGNED_S";
11456 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11457 p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
11458 shname = "F16_ALIGNED_S";
11459 } else {
11460 GGML_ABORT("fatal error");
11461 }
11462 } else if (shader_size == 1) {
11463 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11464 p = ctx->device->pipeline_matmul_f32->a_m;
11465 shname = "F32_ALIGNED_M";
11466 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11467 p = ctx->device->pipeline_matmul_f32_f16->a_m;
11468 shname = "F32_F16_ALIGNED_M";
11469 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11470 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
11471 shname = "F16_F32_ALIGNED_M";
11472 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11473 p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
11474 shname = "F16_ALIGNED_M";
11475 } else {
11476 GGML_ABORT("fatal error");
11477 }
11478 } else if (shader_size == 2) {
11479 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11480 p = ctx->device->pipeline_matmul_f32->a_l;
11481 shname = "F32_ALIGNED_L";
11482 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11483 p = ctx->device->pipeline_matmul_f32_f16->a_l;
11484 shname = "F32_F16_ALIGNED_L";
11485 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11486 p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
11487 shname = "F16_F32_ALIGNED_L";
11488 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11489 p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
11490 shname = "F16_ALIGNED_L";
11491 } else {
11492 GGML_ABORT("fatal error");
11493 }
11494 } else {
11495 GGML_ASSERT(0);
11496 }
11497
11498 const size_t kpad = ggml_vk_align_size(k, p->align);
11499
11500 if (k != kpad) {
11501 if (shader_size == 0) {
11502 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11503 p = ctx->device->pipeline_matmul_f32->s;
11504 shname = "F32_S";
11505 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11506 p = ctx->device->pipeline_matmul_f32_f16->s;
11507 shname = "F32_F16_S";
11508 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11509 p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
11510 shname = "F16_F32_S";
11511 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11512 p = ctx->device->pipeline_matmul_f16.f32acc->s;
11513 shname = "F16_S";
11514 }
11515 } else if (shader_size == 1) {
11516 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11517 p = ctx->device->pipeline_matmul_f32->m;
11518 shname = "F32_M";
11519 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11520 p = ctx->device->pipeline_matmul_f32_f16->m;
11521 shname = "F32_F16_M";
11522 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11523 p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
11524 shname = "F16_F32_M";
11525 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11526 p = ctx->device->pipeline_matmul_f16.f32acc->m;
11527 shname = "F16_M";
11528 }
11529 } else if (shader_size == 2) {
11530 if (std::is_same<float, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11531 p = ctx->device->pipeline_matmul_f32->l;
11532 shname = "F32_L";
11533 } else if (std::is_same<float, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11534 p = ctx->device->pipeline_matmul_f32_f16->l;
11535 shname = "F32_F16_L";
11536 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
11537 p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
11538 shname = "F16_F32_L";
11539 } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
11540 p = ctx->device->pipeline_matmul_f16.f32acc->l;
11541 shname = "F16_L";
11542 }
11543 }
11544 }
11545
11546 ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
11547 if (split_k > 1) {
11548 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
11549
11550 if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
11551 // Resize buffer
11552 if (ctx->prealloc_split_k != nullptr) {
11553 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
11554 }
11555 ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11556 }
11557 }
11558
11559 ggml_pipeline_allocate_descriptor_sets(ctx);
11560
11561 vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11562 vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11563 vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11564
11565 X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
11566 Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
11567 float* d = (float *) malloc(sizeof(float) * d_ne);
11568
11569 for (size_t i = 0; i < x_ne; i++) {
11570 if (std::is_same<float, X_TYPE>()) {
11571 x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
11572 // x[i] = 1.0f;
11573 // x[i] = i + 1;
11574 // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
11575 } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
11576 x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
11577 // x[i] = ggml_fp32_to_fp16(1.0f);
11578 // x[i] = ggml_fp32_to_fp16(i + 1);
11579 // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
11580 } else {
11581 GGML_ABORT("fatal error");
11582 }
11583 }
11584 for (size_t i = 0; i < y_ne; i++) {
11585 if (std::is_same<float, Y_TYPE>()) {
11586 y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
11587 // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
11588 // y[i] = i + 1;
11589 } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
11590 y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
11591 // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
11592 // y[i] = ggml_fp32_to_fp16(i + 1);
11593 } else {
11594 GGML_ABORT("fatal error");
11595 }
11596 }
11597
11598 ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch);
11599 ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
11600
11601 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
11602 ggml_vk_ctx_begin(ctx->device, subctx);
11603 for (size_t i = 0; i < num_it; i++) {
11604 ggml_vk_matmul(
11605 ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k),
11606 m, n, k,
11607 k, k, m, k*m, k*n, m*n,
11608 split_k, batch, batch, batch, 1, 1, n
11609 );
11610 }
11611 ggml_vk_ctx_end(subctx);
11612
11613 auto begin = std::chrono::high_resolution_clock::now();
11614 ggml_vk_submit(subctx, ctx->fence);
11615 VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences");
11616 ctx->device->device.resetFences({ ctx->fence });
11617 ggml_vk_queue_command_pools_cleanup(ctx->device);
11618
11619 auto end = std::chrono::high_resolution_clock::now();
11620 double time = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
11621
11622 // copy dst to host
11623 ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne);
11624
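    // Compute a reference result on the CPU with ggml and compare it against the Vulkan output.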
11625 float * d_chk = (float *) malloc(sizeof(float) * d_ne);
11626
11627 ggml_init_params iparams = {
11628 /*.mem_size =*/ 1024*1024*1024,
11629 /*.mem_buffer =*/ NULL,
11630 /*.no_alloc =*/ true,
11631 };
11632
11633 ggml_context * ggml_ctx = ggml_init(iparams);
11634
11635 ggml_type src0_type;
11636 ggml_type src1_type;
11637
11638 if (std::is_same<float, X_TYPE>()) {
11639 src0_type = GGML_TYPE_F32;
11640 } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
11641 src0_type = GGML_TYPE_F16;
11642 } else {
11643 GGML_ABORT("fatal error");
11644 }
11645 if (std::is_same<float, Y_TYPE>()) {
11646 src1_type = GGML_TYPE_F32;
11647 } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
11648 src1_type = GGML_TYPE_F16;
11649 } else {
11650 GGML_ABORT("fatal error");
11651 }
11652
11653 ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch);
11654 ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch);
11655 ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
11656
11657 src0_ggml->data = x;
11658 src1_ggml->data = y;
11659 tensor_ggml->data = d_chk;
11660
11661 ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
11662 ggml_build_forward_expand(cgraph, tensor_ggml);
11663
11664 ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
11665
11666 ggml_free(ggml_ctx);
11667
11668 double avg_err = 0.0;
11669 int first_err_n = -1;
11670 int first_err_m = -1;
11671 int first_err_b = -1;
11672
11673 for (size_t i = 0; i < m*n*batch; i++) {
11674 double err = std::fabs(d[i] - d_chk[i]);
11675 avg_err += err;
11676
11677 if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
11678 first_err_b = i / (m * n);
11679 first_err_n = (i % (m * n)) / m;
11680 first_err_m = (i % (m * n)) % m;
11681 }
11682 }
11683
11684    avg_err /= m * n * batch;
11685
11686 double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
11687
11688 std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
11689
11690 if (avg_err > 0.1 || std::isnan(avg_err)) {
11691 std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
11692 std::cerr << "Actual result: " << std::endl << std::endl;
11693 ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11694 std::cerr << "Expected result: " << std::endl << std::endl;
11695 ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11696
11697 if (split_k > 1) {
11698 float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
11699 ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
11700
11701 std::cerr << "d_buf0: " << std::endl << std::endl;
11702 ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11703
11704 std::cerr << "d_buf1: " << std::endl << std::endl;
11705 ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11706
11707 std::cerr << "d_buf2: " << std::endl << std::endl;
11708 ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11709
11710 std::cerr << "d_buf3: " << std::endl << std::endl;
11711 ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
11712
11713 free(split_k_buf);
11714 }
11715 }
11716
11717 free(d_chk);
11718
11719 ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
11720
11721 ggml_vk_destroy_buffer(d_X);
11722 ggml_vk_destroy_buffer(d_Y);
11723 ggml_vk_destroy_buffer(d_D);
11724
11725 free(x);
11726 free(y);
11727 free(d);
11728}
11729
11730static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) {
11731 if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
11732 return;
11733 }
11734 i0 = std::max(i0, 5);
11735 i1 = std::max(i1, 5);
11736 i2 = std::max(i2, 0);
11737 i3 = std::max(i3, 0);
11738 fprintf(stderr, " ");
11739 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
11740 fprintf(stderr, "%7d ", idx1);
11741 }
11742 fprintf(stderr, "\n");
11743 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
11744 fprintf(stderr, "%7d: ", idx0);
11745 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
11746 if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
11747 float val;
11748 if (tensor->type == GGML_TYPE_F32) {
11749 val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
11750 } else if (tensor->type == GGML_TYPE_F16) {
11751 val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
11752 } else {
11753 GGML_ABORT("fatal error");
11754 }
11755 fprintf(stderr, "% 7.2f ", val);
11756 } else {
11757 fprintf(stderr, " ");
11758 }
11759 }
11760 fprintf(stderr, "\n");
11761 }
11762}
11763
11764static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) {
11765 ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr);
11766}
11767
11768static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) {
11769 if (quant == GGML_TYPE_F32) {
11770 memcpy(to, from, sizeof(float) * ne);
11771 return;
11772 }
11773
11774 const auto * tt = ggml_get_type_traits(quant);
11775
11776 ggml_to_float_t dequant_fn = tt->to_float;
11777
11778 dequant_fn(from, to, ne);
11779}
11780
11781static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
11782 VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")");
11783 const size_t x_sz = sizeof(float) * ne;
11784 const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne;
11785 const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
11786 float * x = (float *) malloc(x_sz);
11787 void * qx = malloc(qx_sz);
11788 vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11789 vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11790 float * x_ref = (float *) malloc(x_sz);
11791 ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
11792
11793 for (size_t i = 0; i < ne; i++) {
11794 x[i] = rand() / (float)RAND_MAX;
11795 }
11796
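    // Quantize the random data on the CPU, dequantize it both on the CPU (reference) and on the
    // GPU, then compare the two results.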
11797 vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant);
11798
11799 ggml_vk_quantize_data(x, qx, ne, quant);
11800 ggml_vk_dequantize_data(qx, x_ref, ne, quant);
11801
11802 ggml_pipeline_request_descriptor_sets(ctx, p, 1);
11803
11804 ggml_pipeline_allocate_descriptor_sets(ctx);
11805
11806 ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
11807
11808 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
11809 ggml_vk_ctx_begin(ctx->device, subctx);
11810 const std::vector<uint32_t> pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne };
11811 ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1});
11812 ggml_vk_ctx_end(subctx);
11813
11814 auto begin = std::chrono::high_resolution_clock::now();
11815
11816 ggml_vk_submit(subctx, ctx->fence);
11817 VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences");
11818 ctx->device->device.resetFences({ ctx->fence });
11819 ggml_vk_queue_command_pools_cleanup(ctx->device);
11820
11821 auto end = std::chrono::high_resolution_clock::now();
11822
11823 double ms_dequant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
11824 ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16);
11825
11826 int first_err = -1;
11827
11828 double avg_err = 0.0;
11829 for (size_t i = 0; i < ne; i++) {
11830 double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i]));
11831 avg_err += error;
11832
11833 if (first_err < 0 && error > 0.05) {
11834 first_err = i;
11835 }
11836 }
11837
11838 avg_err /= ne;
11839
11840 std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl;
11841
11842 if (avg_err > 0.1) {
11843 std::cerr << "first_error = " << first_err << std::endl;
11844 std::cerr << "Actual result: " << std::endl << std::endl;
11845 for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
11846 std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", ";
11847 }
11848 std::cerr << std::endl << "Expected result: " << std::endl << std::endl;
11849 for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) {
11850 std::cerr << x_ref[i] << ", ";
11851 }
11852 std::cerr << std::endl;
11853 }
11854
11855 ggml_vk_destroy_buffer(x_buf);
11856 ggml_vk_destroy_buffer(qx_buf);
11857
11858 free(x);
11859 free(qx);
11860 free(x_ref);
11861 free(x_chk);
11862}
11863
11864// This does not work without ggml q8_1 quantization support
11865//
11866// typedef uint16_t ggml_half;
11867// typedef uint32_t ggml_half2;
11868//
11869// #define QK8_1 32
11870// typedef struct {
11871// union {
11872// struct {
11873// ggml_half d; // delta
11874// ggml_half s; // d * sum(qs[i])
11875// } GGML_COMMON_AGGR_S;
11876// ggml_half2 ds;
11877// } GGML_COMMON_AGGR_U;
11878// int8_t qs[QK8_1]; // quants
11879// } block_q8_1;
11880//
11881// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
11882// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
11883// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
11884//
11885// const size_t x_sz = sizeof(float) * ne;
11886// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
11887// float * x = (float *) malloc(x_sz);
11888// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
11889// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
11890// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11891// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
11892//
11893// for (size_t i = 0; i < ne; i++) {
11894// x[i] = rand() / (float)RAND_MAX;
11895// }
11896//
11897// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
11898//
11899// ggml_pipeline_request_descriptor_sets(ctx, p, 1);
11900//
11901// ggml_pipeline_allocate_descriptor_sets(ctx);
11902//
11903// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
11904//
11905// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
11906// ggml_vk_ctx_begin(ctx->device, subctx);
11907// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne);
11908// ggml_vk_ctx_end(subctx);
11909//
11910// auto begin = std::chrono::high_resolution_clock::now();
11911//
11912// ggml_vk_submit(subctx, ctx->fence);
11913// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
11914// ctx->device->device.resetFences({ ctx->fence });
11915// ggml_vk_queue_command_pools_cleanup(ctx->device);
11916//
11917// auto end = std::chrono::high_resolution_clock::now();
11918//
11919// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
11920// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
11921//
11922// ggml_vk_quantize_data(x, qx_res, ne, quant);
11923//
11924// int first_err = -1;
11925//
11926// for (size_t i = 0; i < ne / 32; i++) {
11927// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
11928//
11929// if (first_err < 0 && error > 0.1) {
11930// first_err = i;
11931// }
11932//
11933// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
11934//
11935// if (first_err < 0 && error > 0.1) {
11936// first_err = i;
11937// }
11938//
11939// for (size_t j = 0; j < 32; j++) {
11940// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
11941//
11942// if (first_err < 0 && error > 1) {
11943// first_err = i;
11944// }
11945// }
11946// }
11947//
11948// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
11949//
11950// if (first_err != -1) {
11951// std::cerr << "first_error = " << first_err << std::endl;
11952// std::cerr << "Actual result: " << std::endl << std::endl;
11953// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
11954// for (size_t j = 0; j < 32; j++) {
11955// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
11956// }
11957// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
11958// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
11959// for (size_t j = 0; j < 32; j++) {
11960// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
11961// }
11962// std::cerr << std::endl;
11963// }
11964//
11965// ggml_vk_destroy_buffer(x_buf);
11966// ggml_vk_destroy_buffer(qx_buf);
11967//
11968// free(x);
11969// free(qx);
11970// free(qx_res);
11971// }
11972
11973static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
11974 VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
11975 const size_t x_ne = m * k * batch;
11976 const size_t y_ne = k * n * batch;
11977 const size_t d_ne = m * n * batch;
11978
11979 vk_matmul_pipeline2 * pipelines;
11980
11981 if (mmq) {
11982 pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
11983 } else {
11984 pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
11985 }
11986
11987 const bool fp16acc = ctx->device->fp16;
11988
11989 vk_pipeline p;
11990 std::string shname;
11991 if (shader_size == 0) {
11992 p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
11993 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
11994 } else if (shader_size == 1) {
11995 p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
11996 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
11997 } else if (shader_size == 2) {
11998 p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
11999 shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
12000 } else {
12001 GGML_ASSERT(0);
12002 }
12003
12004 const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
12005
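    // If k is a multiple of the pipeline's alignment, the aligned shader variant selected above
    // can be used; otherwise (and always for mmq, which has no aligned variant here) fall back
    // to the unaligned variant.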
12006 if (mmq || k != kpad) {
12007 if (shader_size == 0) {
12008 p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
12009 shname = std::string(ggml_type_name(quant)) + "_S";
12010 } else if (shader_size == 1) {
12011 p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
12012 shname = std::string(ggml_type_name(quant)) + "_M";
12013 } else if (shader_size == 2) {
12014 p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
12015 shname = std::string(ggml_type_name(quant)) + "_L";
12016 } else {
12017 GGML_ASSERT(0);
12018 }
12019 }
12020
12021 if (p == nullptr) {
12022 std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
12023 return;
12024 }
12025
12026 const size_t x_sz = sizeof(float) * x_ne;
12027 const size_t y_sz = sizeof(float) * y_ne;
12028 const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
12029 const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
12030 const size_t d_sz = sizeof(float) * d_ne;
12031 float * x = (float *) malloc(x_sz);
12032 float * y = (float *) malloc(y_sz);
12033 void * qx = malloc(qx_sz);
12034 vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
12035 vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
12036 vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
12037 vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal});
12038 float * d = (float *) malloc(d_sz);
12039 float * d_chk = (float *) malloc(d_sz);
12040
12041 for (size_t i = 0; i < x_ne; i++) {
12042 x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
12043 // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
12044 // x[i] = i % k;
12045 }
12046
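    // Quantize the input on the CPU; the same quantized data is uploaded to the GPU below and
    // reused as src0 of the CPU reference matmul.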
12047 ggml_vk_quantize_data(x, qx, x_ne, quant);
12048
12049 for (size_t i = 0; i < y_ne; i++) {
12050 y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
12051 // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
12052 // y[i] = i % k;
12053 }
12054
12055 ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
12056 if (split_k > 1) {
12057 ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
12058
12059 if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) {
12060 // Resize buffer
12061 if (ctx->prealloc_split_k != nullptr) {
12062 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
12063 }
12064 ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal});
12065 }
12066 }
12067 if (mmq) {
12068 vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
12069 ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
12070 }
12071
12072 ggml_pipeline_allocate_descriptor_sets(ctx);
12073
12074 ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
12075 ggml_vk_buffer_write(y_buf, 0, y, y_sz);
12076
12077 vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12078 ggml_vk_ctx_begin(ctx->device, subctx);
12079 if (mmq) {
12080 for (size_t i = 0; i < num_it; i++) {
12081 ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
12082 ggml_vk_matmul(
12083 ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
12084 m, n, k,
12085 k, k, m, k*m, k*n, m*n,
12086 split_k, batch, batch, batch, 1, 1, n
12087 );
12088 }
12089 } else {
12090 for (size_t i = 0; i < num_it; i++) {
12091 ggml_vk_matmul(
12092 ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
12093 m, n, k,
12094 k, k, m, k*m, k*n, m*n,
12095 split_k, batch, batch, batch, 1, 1, n
12096 );
12097 }
12098 }
12099 ggml_vk_ctx_end(subctx);
12100
12101 auto begin = std::chrono::high_resolution_clock::now();
12102
12103 ggml_vk_submit(subctx, ctx->fence);
12104    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant_matmul waitForFences");
12105 ctx->device->device.resetFences({ ctx->fence });
12106 ggml_vk_queue_command_pools_cleanup(ctx->device);
12107
12108 auto end = std::chrono::high_resolution_clock::now();
12109
12110 double time_ms = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
12111 ggml_vk_buffer_read(d_buf, 0, d, d_sz);
12112
12113 ggml_init_params iparams = {
12114 /*.mem_size =*/ 1024*1024*1024,
12115 /*.mem_buffer =*/ NULL,
12116 /*.no_alloc =*/ true,
12117 };
12118
12119 ggml_context * ggml_ctx = ggml_init(iparams);
12120
12121 ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch);
12122 ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch);
12123 ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml);
12124
12125 src0_ggml->data = qx;
12126 src1_ggml->data = y;
12127 tensor_ggml->data = d_chk;
12128
12129 ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
12130 ggml_build_forward_expand(cgraph, tensor_ggml);
12131
12132 ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1);
12133
12134 ggml_free(ggml_ctx);
12135
12136 double avg_err = 0.0;
12137 int first_err_n = -1;
12138 int first_err_m = -1;
12139 int first_err_b = -1;
12140
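    // Decompose the flat output index into (batch, column, row); m is the innermost,
    // contiguous dimension of each m x n batch slice.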
12141 for (size_t i = 0; i < m*n*batch; i++) {
12142 double err = std::fabs(d[i] - d_chk[i]);
12143 avg_err += err;
12144
12145 if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
12146 first_err_b = i / (m * n);
12147 first_err_n = (i % (m * n)) / m;
12148 first_err_m = (i % (m * n)) % m;
12149 }
12150 }
12151
12152    avg_err /= m * n * batch;
12153
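    // 2*m*n*k floating point operations per matmul, times batch and iteration count,
    // divided by the elapsed time in seconds, expressed in TFLOPS.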
12154 double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
12155
12156 std::cerr << "TEST dequant matmul " << shname;
12157 if (mmq) {
12158 std::cerr << " mmq";
12159 }
12160 std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
12161
12162 if (avg_err > 0.01 || std::isnan(avg_err)) {
12163 std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
12164 std::cerr << "Actual result: " << std::endl << std::endl;
12165 ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
12166 std::cerr << std::endl;
12167 std::cerr << "Expected result: " << std::endl << std::endl;
12168 ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
12169
12170 std::cerr << "src0: " << std::endl << std::endl;
12171 ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
12172 std::cerr << std::endl;
12173 std::cerr << "src1: " << std::endl << std::endl;
12174 ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
12175
12176 if (split_k > 1) {
12177 float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
12178 ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
12179
12180            // Print each of the split_k partial results; only split_k partials exist in the buffer.
12181            for (size_t sk = 0; sk < split_k; sk++) {
12182                std::cerr << "d_buf" << sk << ": " << std::endl << std::endl;
12183                ggml_vk_print_matrix_area(split_k_buf + sk * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
12184            }
12191
12192 free(split_k_buf);
12193 }
12194 }
12195
12196 ggml_vk_destroy_buffer(qx_buf);
12197 ggml_vk_destroy_buffer(y_buf);
12198 ggml_vk_destroy_buffer(qy_buf);
12199 ggml_vk_destroy_buffer(d_buf);
12200
12201 free(x);
12202 free(qx);
12203 free(y);
12204 free(d);
12205 free(d_chk);
12206}
12207#endif
12208
12209static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx) {
12210#if defined(GGML_VULKAN_RUN_TESTS)
12211 const std::vector<size_t> vals {
12212 512, 512, 128,
12213 128, 512, 512,
12214 4096, 512, 4096,
12215 11008, 512, 4096,
12216 4096, 512, 11008,
12217 32000, 512, 4096,
12218 8, 8, 8,
12219 100, 46, 576,
12220 623, 111, 128,
12221 100, 46, 558,
12222 512, 1, 256,
12223 128, 110, 622,
12224 511, 511, 127,
12225 511, 511, 7,
12226 511, 511, 17,
12227 49, 49, 128,
12228 128, 49, 49,
12229 4096, 49, 4096,
12230 };
12231 const size_t num_it = 100;
12232
12233 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
12234 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
12235 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
12236
12237 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
12238 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
12239 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
12240
12241 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
12242 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
12243 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
12244
12245 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
12246 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
12247 ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
12248
12249 abort();
12250
12251 for (size_t i = 0; i < vals.size(); i += 3) {
12252 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
12253 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
12254 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
12255 std::cerr << '\n';
12256 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
12257 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
12258 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
12259 std::cerr << '\n';
12260 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
12261 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
12262 ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
12263 std::cerr << '\n' << std::endl;
12264
12265 if (vals[i + 2] % 32 == 0) {
12266 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
12267 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
12268 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
12269 std::cerr << '\n';
12270 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
12271 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
12272 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
12273 std::cerr << '\n';
12274 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
12275 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
12276 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
12277 std::cerr << '\n' << std::endl;
12278 }
12279
12280 if (vals[i + 2] % 256 == 0) {
12281 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
12282 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
12283 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
12284 std::cerr << '\n';
12285 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
12286 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
12287 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
12288 std::cerr << '\n';
12289 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
12290 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
12291 ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
12292 std::cerr << '\n' << std::endl;
12293 }
12294 }
12295
12296 GGML_ABORT("fatal error");
12297#endif
12298
12299 if (subctx) {
12300 // Submit and wait for any pending work before reallocating the buffers
12301 ggml_vk_ctx_end(subctx);
12302 ggml_vk_submit(subctx, {});
12303 ctx->submit_pending = true;
12304 ggml_vk_synchronize(ctx);
12305 GGML_ASSERT(ctx->compute_ctx.expired());
12306 ggml_vk_ctx_begin(ctx->device, subctx);
12307 ctx->compute_ctx = subctx;
12308 }
12309
12310 if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
12311 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")");
12312 // Resize buffer
12313 if (ctx->prealloc_x != nullptr) {
12314 ggml_vk_destroy_buffer(ctx->prealloc_x);
12315 }
12316 ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x);
12317 }
12318 if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) {
12319 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")");
12320 // Resize buffer
12321 if (ctx->prealloc_y != nullptr) {
12322 ggml_vk_destroy_buffer(ctx->prealloc_y);
12323 }
12324 ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
12325 ctx->prealloc_y_last_tensor_used = nullptr;
12326 }
12327 if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
12328 VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
12329 // Resize buffer
12330 if (ctx->prealloc_split_k != nullptr) {
12331 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
12332 }
12333 ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k);
12334 }
12335 if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) {
12336        VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")");
12337 // Resize buffer
12338 if (ctx->prealloc_add_rms_partials != nullptr) {
12339 ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
12340 }
12341 ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials);
12342 }
12343}
12344
12345static void ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool almost_ready);
12346
12347// Returns true if the node has enqueued work into the queue, false otherwise
12348// If submit is true, all operations queued so far are submitted to Vulkan, overlapping command buffer recording with GPU execution.
12349static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool last_node, bool almost_ready, bool submit){
12350 ggml_tensor * node = cgraph->nodes[node_idx];
12351 if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
12352 return false;
12353 }
12354 if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
12355 return false;
12356 }
12357
12358 VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
12359 ctx->semaphore_idx = 0;
12360
12361 ggml_tensor * src0 = node->src[0];
12362 ggml_tensor * src1 = node->src[1];
12363 ggml_tensor * src2 = node->src[2];
12364 ggml_tensor * src3 = node->src[3];
12365
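    // If this ADD feeds a single-row RMS_NORM, the add shader can also accumulate partial sums
    // for the fused add+rms_norm path, provided the preallocated partials buffer has enough room.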
12366 if (node->op == GGML_OP_ADD) {
12367 int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops;
12368 if (next_node_idx < cgraph->n_nodes &&
12369 cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM &&
12370 cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] &&
12371 ggml_nrows(cgraph->nodes[next_node_idx]) == 1 &&
12372 ctx->device->add_rms_fusion) {
12373 uint32_t size = ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]);
12374 ctx->do_add_rms_partials_offset_calculation = true;
12375 if (ctx->prealloc_size_add_rms_partials_offset + size <= ctx->prealloc_size_add_rms_partials) {
12376 ctx->do_add_rms_partials = true;
12377 }
12378 }
12379 }
12380
12381 vk_context compute_ctx;
12382
12383 if (ctx->compute_ctx.expired()) {
12384 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12385 ctx->compute_ctx = compute_ctx;
12386 ggml_vk_ctx_begin(ctx->device, compute_ctx);
12387 } else {
12388 compute_ctx = ctx->compute_ctx.lock();
12389 }
12390
12391 {
12392        // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers
12393 // to synchronize them. This handles most "normal" synchronization when computing the graph, and when
12394 // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers
12395 // outside of this logic. When a node uses one of the prealloc buffers for something like
12396 // dequantization or split_k, additional synchronization is needed between those passes.
12397 bool need_sync = false;
12398
12399 // Check whether "node" requires synchronization. The node requires synchronization if it
12400 // overlaps in memory with another unsynchronized node and at least one of them is a write.
12401 // Destination nodes are checked against both the written/read lists. Source nodes are only
12402 // checked against the written list. Two nodes overlap in memory if they come from the same
12403 // buffer and the tensor or view ranges overlap.
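        // Example: if an unsynchronized node wrote bytes [0, 256) of a buffer and the current
        // node reads bytes [128, 384) of the same buffer, the ranges overlap and a barrier is
        // required before the current node may run.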
12404 auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector<const ggml_tensor *> &unsynced_nodes) -> bool {
12405 if (unsynced_nodes.size() == 0) {
12406 return false;
12407 }
12408 auto n_base = vk_tensor_offset(node) + node->view_offs;
12409 auto n_size = ggml_nbytes(node);
12410 ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context;
12411 vk_buffer a_buf = a_buf_ctx->dev_buffer;
12412 for (auto &other : unsynced_nodes) {
12413 ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context;
12414 vk_buffer o_buf = o_buf_ctx->dev_buffer;
12415 if (a_buf == o_buf) {
12416 auto o_base = vk_tensor_offset(other) + other->view_offs;
12417 auto o_size = ggml_nbytes(other);
12418
12419 if ((o_base <= n_base && n_base < o_base + o_size) ||
12420 (n_base <= o_base && o_base < n_base + n_size)) {
12421 return true;
12422 }
12423 }
12424 }
12425 return false;
12426 };
12427
12428 // For all fused ops, check if the destination node or any of the source
12429 // nodes require synchronization.
12430 for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) {
12431 const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
12432 // If the node actually writes to memory, then check if it needs to sync
12433 if (ctx->fused_ops_write_mask & (1 << i)) {
12434 if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) {
12435 need_sync = true;
12436 break;
12437 }
12438 }
12439 for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
12440 if (!cur_node->src[j]) {
12441 continue;
12442 }
12443 if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) {
12444 need_sync = true;
12445 break;
12446 }
12447 }
12448 }
12449
12450 if (need_sync) {
12451 if (vk_enable_sync_logger) {
12452 std::cerr << "sync" << std::endl;
12453 }
12454 ctx->unsynced_nodes_written.clear();
12455 ctx->unsynced_nodes_read.clear();
12456 ggml_vk_sync_buffers(ctx, compute_ctx);
12457
12458 if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
12459 ctx->query_node_idx[ctx->query_idx] = node_idx;
12460 compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
12461 }
12462 }
12463 // Add all fused nodes to the unsynchronized lists.
12464 for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
12465 const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
12466 // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
12467 if (ctx->fused_ops_write_mask & (1 << i)) {
12468 ctx->unsynced_nodes_written.push_back(cur_node);
12469 }
12470 for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
12471 if (!cur_node->src[j]) {
12472 continue;
12473 }
12474 ctx->unsynced_nodes_read.push_back(cur_node->src[j]);
12475 }
12476 }
12477 }
12478 if (vk_enable_sync_logger) {
12479 for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
12480 auto *n = cgraph->nodes[node_idx + i];
12481 std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " << n->name;
12482 if (n->op == GGML_OP_GLU) {
12483 std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
12484 }
12485 if (n->op == GGML_OP_ROPE) {
12486 const int mode = ((const int32_t *) n->op_params)[2];
12487 std::cerr << " rope mode: " << mode;
12488 }
12489 std::cerr << std::endl;
12490 }
12491 }
12492
12493 switch (node->op) {
12494 case GGML_OP_REPEAT:
12495 ggml_vk_repeat(ctx, compute_ctx, src0, node);
12496
12497 break;
12498 case GGML_OP_REPEAT_BACK:
12499 ggml_vk_repeat_back(ctx, compute_ctx, src0, node);
12500
12501 break;
12502 case GGML_OP_ACC:
12503 ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
12504
12505 break;
12506 case GGML_OP_GET_ROWS:
12507 ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node);
12508
12509 break;
12510 case GGML_OP_ADD:
12511 if (ctx->num_additional_fused_ops) {
12512 ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx);
12513 } else {
12514 ggml_vk_add(ctx, compute_ctx, src0, src1, node);
12515 }
12516 break;
12517 case GGML_OP_SUB:
12518 ggml_vk_sub(ctx, compute_ctx, src0, src1, node);
12519
12520 break;
12521 case GGML_OP_MUL:
12522 ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
12523
12524 break;
12525 case GGML_OP_DIV:
12526 ggml_vk_div(ctx, compute_ctx, src0, src1, node);
12527
12528 break;
12529 case GGML_OP_ADD_ID:
12530 ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node);
12531
12532 break;
12533 case GGML_OP_CONCAT:
12534 ggml_vk_concat(ctx, compute_ctx, src0, src1, node);
12535
12536 break;
12537 case GGML_OP_UPSCALE:
12538 ggml_vk_upscale(ctx, compute_ctx, src0, node);
12539
12540 break;
12541 case GGML_OP_ADD1:
12542 ggml_vk_add1(ctx, compute_ctx, src0, src1, node);
12543
12544 break;
12545 case GGML_OP_ARANGE:
12546 ggml_vk_arange(ctx, compute_ctx, node);
12547
12548 break;
12549 case GGML_OP_FILL:
12550 ggml_vk_fill(ctx, compute_ctx, node);
12551
12552 break;
12553 case GGML_OP_SCALE:
12554 ggml_vk_scale(ctx, compute_ctx, src0, node);
12555
12556 break;
12557 case GGML_OP_SQR:
12558 ggml_vk_sqr(ctx, compute_ctx, src0, node);
12559
12560 break;
12561 case GGML_OP_SQRT:
12562 ggml_vk_sqrt(ctx, compute_ctx, src0, node);
12563
12564 break;
12565 case GGML_OP_SIN:
12566 ggml_vk_sin(ctx, compute_ctx, src0, node);
12567
12568 break;
12569 case GGML_OP_COS:
12570 ggml_vk_cos(ctx, compute_ctx, src0, node);
12571
12572 break;
12573 case GGML_OP_LOG:
12574 ggml_vk_log(ctx, compute_ctx, src0, node);
12575
12576 break;
12577 case GGML_OP_TRI:
12578 ggml_vk_tri(ctx, compute_ctx, src0, node);
12579
12580 break;
12581 case GGML_OP_DIAG:
12582 ggml_vk_diag(ctx, compute_ctx, src0, node);
12583
12584 break;
12585 case GGML_OP_CLAMP:
12586 ggml_vk_clamp(ctx, compute_ctx, src0, node);
12587
12588 break;
12589 case GGML_OP_PAD:
12590 ggml_vk_pad(ctx, compute_ctx, src0, node);
12591
12592 break;
12593 case GGML_OP_ROLL:
12594 ggml_vk_roll(ctx, compute_ctx, src0, node);
12595
12596 break;
12597 case GGML_OP_CPY:
12598 case GGML_OP_CONT:
12599 case GGML_OP_DUP:
12600 ggml_vk_cpy(ctx, compute_ctx, src0, node);
12601
12602 break;
12603 case GGML_OP_SET_ROWS:
12604 ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node);
12605
12606 break;
12607 case GGML_OP_SILU_BACK:
12608 ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node);
12609
12610 break;
12611 case GGML_OP_NORM:
12612 ggml_vk_norm(ctx, compute_ctx, src0, node);
12613
12614 break;
12615 case GGML_OP_GROUP_NORM:
12616 ggml_vk_group_norm(ctx, compute_ctx, src0, node);
12617
12618 break;
12619 case GGML_OP_RMS_NORM:
12620 ggml_vk_rms_norm(ctx, compute_ctx, cgraph, node_idx, (float *)node->op_params);
12621 break;
12622 case GGML_OP_RMS_NORM_BACK:
12623 ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node);
12624
12625 break;
12626 case GGML_OP_L2_NORM:
12627 ggml_vk_l2_norm(ctx, compute_ctx, src0, node);
12628
12629 break;
12630 case GGML_OP_UNARY:
12631 if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12632 ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12633 break;
12634 }
12635
12636 switch (ggml_get_unary_op(node)) {
12637 case GGML_UNARY_OP_EXP:
12638 case GGML_UNARY_OP_SILU:
12639 case GGML_UNARY_OP_GELU:
12640 case GGML_UNARY_OP_GELU_ERF:
12641 case GGML_UNARY_OP_GELU_QUICK:
12642 case GGML_UNARY_OP_RELU:
12643 case GGML_UNARY_OP_NEG:
12644 case GGML_UNARY_OP_TANH:
12645 case GGML_UNARY_OP_SIGMOID:
12646 case GGML_UNARY_OP_HARDSIGMOID:
12647 case GGML_UNARY_OP_HARDSWISH:
12648 case GGML_UNARY_OP_ABS:
12649 case GGML_UNARY_OP_SOFTPLUS:
12650 case GGML_UNARY_OP_STEP:
12651 case GGML_UNARY_OP_ROUND:
12652 case GGML_UNARY_OP_CEIL:
12653 case GGML_UNARY_OP_FLOOR:
12654 case GGML_UNARY_OP_TRUNC:
12655 ggml_vk_unary(ctx, compute_ctx, src0, node);
12656 break;
12657 case GGML_UNARY_OP_XIELU:
12658 ggml_vk_xielu(ctx, compute_ctx, src0, node);
12659 break;
12660 default:
12661 return false;
12662 }
12663 break;
12664 case GGML_OP_GLU:
12665 switch (ggml_get_glu_op(node)) {
12666 case GGML_GLU_OP_GEGLU:
12667 case GGML_GLU_OP_REGLU:
12668 case GGML_GLU_OP_SWIGLU:
12669 case GGML_GLU_OP_SWIGLU_OAI:
12670 case GGML_GLU_OP_GEGLU_ERF:
12671 case GGML_GLU_OP_GEGLU_QUICK:
12672 ggml_vk_glu(ctx, compute_ctx, src0, src1, node);
12673 break;
12674 default:
12675 return false;
12676 }
12677 break;
12678 case GGML_OP_DIAG_MASK_INF:
12679 ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node);
12680
12681 break;
12682 case GGML_OP_SOFT_MAX:
12683 if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12684 ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12685 } else {
12686 ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
12687 }
12688
12689 break;
12690 case GGML_OP_SOFT_MAX_BACK:
12691 ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node);
12692
12693 break;
12694 case GGML_OP_ROPE:
12695 ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, false);
12696
12697 break;
12698 case GGML_OP_ROPE_BACK:
12699 ggml_vk_rope(ctx, compute_ctx, cgraph, node_idx, true);
12700
12701 break;
12702 case GGML_OP_ARGSORT:
12703 if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
12704 ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
12705 } else {
12706 ggml_vk_argsort(ctx, compute_ctx, src0, node);
12707 }
12708
12709 break;
12710 case GGML_OP_TOP_K:
12711 ggml_vk_topk(ctx, compute_ctx, src0, node);
12712
12713 break;
12714 case GGML_OP_SUM:
12715 ggml_vk_sum(ctx, compute_ctx, src0, node);
12716
12717 break;
12718 case GGML_OP_SUM_ROWS:
12719 ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
12720
12721 break;
12722 case GGML_OP_CUMSUM:
12723 ggml_vk_cumsum(ctx, compute_ctx, src0, node);
12724
12725 break;
12726 case GGML_OP_MEAN:
12727 ggml_vk_mean(ctx, compute_ctx, src0, node);
12728
12729 break;
12730 case GGML_OP_ARGMAX:
12731 ggml_vk_argmax(ctx, compute_ctx, src0, node);
12732
12733 break;
12734 case GGML_OP_COUNT_EQUAL:
12735 ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
12736
12737 break;
12738 case GGML_OP_SOLVE_TRI:
12739 ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
12740
12741 break;
12742 case GGML_OP_IM2COL:
12743 ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
12744
12745 break;
12746 case GGML_OP_IM2COL_3D:
12747 ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node);
12748
12749 break;
12750 case GGML_OP_TIMESTEP_EMBEDDING:
12751 ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
12752
12753 break;
12754 case GGML_OP_CONV_TRANSPOSE_1D:
12755 ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
12756
12757 break;
12758 case GGML_OP_POOL_2D:
12759 ggml_vk_pool_2d(ctx, compute_ctx, src0, node);
12760
12761 break;
12762 case GGML_OP_CONV_2D:
12763 case GGML_OP_CONV_TRANSPOSE_2D:
12764 ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node);
12765
12766 break;
12767 case GGML_OP_CONV_2D_DW:
12768 ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node);
12769
12770 break;
12771 case GGML_OP_LEAKY_RELU:
12772 ggml_vk_leaky_relu(ctx, compute_ctx, src0, node);
12773
12774 break;
12775 case GGML_OP_MUL_MAT:
12776 ggml_vk_mul_mat(ctx, compute_ctx, cgraph, node_idx);
12777
12778 break;
12779 case GGML_OP_MUL_MAT_ID:
12780 ggml_vk_mul_mat_id(ctx, compute_ctx, cgraph, node_idx);
12781
12782 break;
12783
12784 case GGML_OP_FLASH_ATTN_EXT:
12785 ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node);
12786
12787 break;
12788
12789 case GGML_OP_RWKV_WKV6:
12790 ggml_vk_rwkv_wkv6(ctx, compute_ctx, node);
12791
12792 break;
12793
12794 case GGML_OP_RWKV_WKV7:
12795 ggml_vk_rwkv_wkv7(ctx, compute_ctx, node);
12796
12797 break;
12798
12799 case GGML_OP_SSM_SCAN:
12800 ggml_vk_ssm_scan(ctx, compute_ctx, node);
12801
12802 break;
12803
12804 case GGML_OP_SSM_CONV:
12805 ggml_vk_ssm_conv(ctx, compute_ctx, node);
12806
12807 break;
12808
12809 case GGML_OP_OPT_STEP_ADAMW:
12810 ggml_vk_opt_step_adamw(ctx, compute_ctx, node);
12811
12812 break;
12813
12814 case GGML_OP_OPT_STEP_SGD:
12815 ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node);
12816
12817 break;
12818 default:
12819 return false;
12820 }
12821
12822 ctx->tensor_ctxs[node_idx] = compute_ctx;
12823
12824#if defined(GGML_VULKAN_CHECK_RESULTS)
12825 // Force context reset on each node so that each tensor ends up in its own context
12826 // and can be run and compared to its CPU equivalent separately
12827 last_node = true;
12828#endif
12829
12830 if (submit || last_node) {
12831 ggml_vk_ctx_end(compute_ctx);
12832
12833        // TODO: it would probably be better to pass an exit_node flag to ggml_vk_compute_forward
12834 if (last_node) {
12835 compute_ctx->exit_tensor_idx = node_idx_begin;
12836 }
12837 else {
12838 compute_ctx->exit_tensor_idx = -1;
12839 }
12840
12841 ctx->compute_ctx.reset();
12842
12843 ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, almost_ready);
12844 }
12845 return true;
12846}
12847
12848static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool almost_ready = false) {
12849 GGML_UNUSED(cgraph);
12850 GGML_UNUSED(tensor);
12851
12852 VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")");
12853
12854 vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock();
12855
12856 // Only run if ctx hasn't been submitted yet
12857 if (!subctx->seqs.empty()) {
12858#ifdef GGML_VULKAN_CHECK_RESULTS
12859 ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
12860#endif
12861
12862 // Do staging buffer copies
12863 for (auto& cpy : subctx->in_memcpys) {
12864 memcpy(cpy.dst, cpy.src, cpy.n);
12865 }
12866
12867 for (auto& mset : subctx->memsets) {
12868 memset(mset.dst, mset.val, mset.n);
12869 }
12870
12871 if (almost_ready && !ctx->almost_ready_fence_pending) {
12872 ggml_vk_submit(subctx, ctx->almost_ready_fence);
12873 ctx->almost_ready_fence_pending = true;
12874 } else {
12875 ggml_vk_submit(subctx, {});
12876 }
12877 ctx->submit_pending = true;
12878
12879#ifdef GGML_VULKAN_CHECK_RESULTS
12880 ggml_vk_synchronize(ctx);
12881 ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
12882#endif
12883 }
12884
12885 if (tensor_idx == subctx->exit_tensor_idx) {
12886 // Do staging buffer copies
12887 for (auto& cpy : subctx->out_memcpys) {
12888 memcpy(cpy.dst, cpy.src, cpy.n);
12889 }
12890 subctx->in_memcpys.clear();
12891 subctx->out_memcpys.clear();
12892 subctx->memsets.clear();
12893 }
12894}
12895
12896// Clean up after graph processing is done
12897static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
12898 VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
12899 ctx->prealloc_y_last_pipeline_used = {};
12900
12901 ctx->unsynced_nodes_written.clear();
12902 ctx->unsynced_nodes_read.clear();
12903 ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
12904
12905 ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
12906
12907 for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
12908 ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
12909 }
12910 ctx->gc.semaphores.clear();
12911
12912 for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) {
12913 ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s });
12914 }
12915 ctx->gc.tl_semaphores.clear();
12916 ctx->semaphore_idx = 0;
12917
12918 ctx->event_idx = 0;
12919
12920 for (auto& event : ctx->gc.events) {
12921 ctx->device->device.resetEvent(event);
12922 }
12923
12924 ctx->tensor_ctxs.clear();
12925 ctx->gc.contexts.clear();
12926 ctx->pipeline_descriptor_set_requirements = 0;
12927 ctx->descriptor_set_idx = 0;
12928}
12929
12930// Clean up on backend free
12931static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
12932 VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
12933 // discard any unsubmitted command buffers
12934 ctx->compute_ctx.reset();
12935 // wait for any pending command buffers to finish
12936 ggml_vk_synchronize(ctx);
12937
12938 ggml_vk_graph_cleanup(ctx);
12939
12940 ggml_vk_destroy_buffer(ctx->prealloc_x);
12941 ggml_vk_destroy_buffer(ctx->prealloc_y);
12942 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
12943 ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials);
12944 ggml_vk_destroy_buffer(ctx->sync_staging);
12945
12946 ctx->prealloc_y_last_pipeline_used = nullptr;
12947
12948 ctx->prealloc_size_x = 0;
12949 ctx->prealloc_size_y = 0;
12950 ctx->prealloc_size_split_k = 0;
12951
12952 for (auto& event : ctx->gc.events) {
12953 ctx->device->device.destroyEvent(event);
12954 }
12955 ctx->gc.events.clear();
12956
12957 ctx->device->device.destroyFence(ctx->fence);
12958 ctx->device->device.destroyFence(ctx->almost_ready_fence);
12959
12960 for (auto& pool : ctx->descriptor_pools) {
12961 ctx->device->device.destroyDescriptorPool(pool);
12962 }
12963 ctx->descriptor_pools.clear();
12964 ctx->descriptor_sets.clear();
12965
12966 ctx->compute_cmd_pool.destroy(ctx->device->device);
12967 if (vk_perf_logger_enabled) {
12968 ctx->perf_logger->print_timings(true);
12969 }
12970}
12971
12972static int ggml_vk_get_device_count() {
12973 ggml_vk_instance_init();
12974
12975 return vk_instance.device_indices.size();
12976}
12977
12978static void ggml_vk_get_device_description(int device, char * description, size_t description_size) {
12979 ggml_vk_instance_init();
12980
12981 std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
12982
12983 vk::PhysicalDeviceProperties props;
12984 devices[device].getProperties(&props);
12985
12986 snprintf(description, description_size, "%s", props.deviceName.data());
12987}
12988
12989// backend interface
12990
12991#define UNUSED GGML_UNUSED
12992
12993// device backend
12994
12995static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
12996 return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
12997}
12998
12999static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
13000 VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()");
13001 ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
13002 ggml_vk_destroy_buffer(ctx->dev_buffer);
13003 delete ctx;
13004}
13005
13006static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
13007 return vk_ptr_base;
13008
13009 UNUSED(buffer);
13010}
13011
13012static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
13013 VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
13014 if (tensor->view_src != nullptr) {
13015 GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
13016 }
13017 return GGML_STATUS_SUCCESS;
13018}
13019
13020static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
13021 VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")");
13022 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
13023 vk_buffer buf = buf_ctx->dev_buffer;
13024
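    // Replicate the byte into all four bytes of a 32-bit word (e.g. 0xAB -> 0xABABABAB),
    // since the buffer memset path operates on 32-bit values.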
13025 uint32_t val32 = (uint32_t)value * 0x01010101;
13026 ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
13027}
13028
13029static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
13030 VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
13031 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
13032 vk_buffer buf = buf_ctx->dev_buffer;
13033
13034 ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
13035}
13036
13037static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
13038 VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
13039 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
13040
13041 vk_buffer buf = buf_ctx->dev_buffer;
13042
13043 ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
13044}
13045
13046static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
13047 if (ggml_backend_buffer_is_vk(src->buffer)) {
13048 ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
13049 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
13050
13051 vk_buffer src_buf = src_buf_ctx->dev_buffer;
13052 vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
13053
13054 ggml_vk_buffer_copy(dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
13055
13056 return true;
13057 }
13058 return false;
13059
13060 UNUSED(buffer);
13061}
13062
13063static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
13064 ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
13065
13066 ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size);
13067}
13068
13069static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
13070 /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
13071 /* .get_base = */ ggml_backend_vk_buffer_get_base,
13072 /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
13073 /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
13074 /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
13075 /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
13076 /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
13077 /* .clear = */ ggml_backend_vk_buffer_clear,
13078 /* .reset = */ NULL,
13079};
13080
13081// vk buffer type
13082static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) {
13083 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
13084
13085 return ctx->name.c_str();
13086}
13087
13088static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
13089 VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")");
13090 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
13091
13092 vk_buffer dev_buffer = nullptr;
13093 try {
13094 dev_buffer = ggml_vk_create_buffer_device(ctx->device, size);
13095 } catch (const vk::SystemError& e) {
13096 return nullptr;
13097 }
13098
13099 ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name);
13100
13101 return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
13102}
13103
13104static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
13105 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
13106 return ctx->device->properties.limits.minStorageBufferOffsetAlignment;
13107}
13108
13109static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
13110 ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
13111 return ctx->device->suballocation_block_size;
13112}
13113
13114static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
13115 return ggml_nbytes(tensor);
13116
13117 UNUSED(buft);
13118}
13119
13120ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) {
13121 ggml_vk_instance_init();
13122
13123 VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")");
13124
13125 vk_device dev = ggml_vk_get_device(dev_num);
13126
13127 return &dev->buffer_type;
13128}
13129
13130// host buffer type
13131
13132static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
13133 return GGML_VK_NAME "_Host";
13134
13135 UNUSED(buft);
13136}
13137
13138static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
13139 return GGML_VK_NAME "_Host";
13140
13141 UNUSED(buffer);
13142}
13143
13144static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
13145 VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
13146 ggml_vk_host_free(vk_instance.devices[0], buffer->context);
13147}
13148
13149static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
13150 VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")");
13151
13152 size += 32; // Behave like the CPU buffer type
13153 void * ptr = nullptr;
13154 try {
13155 ptr = ggml_vk_host_malloc(vk_instance.devices[0], size);
13156 } catch (vk::SystemError& e) {
13157 GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what());
13158 // fallback to cpu buffer
13159 return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
13160 }
13161
13162 ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
13163 buffer->buft = buft;
13164 buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
13165
13166 return buffer;
13167
13168 UNUSED(buft);
13169}
13170
13171static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
13172 return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment;
13173
13174 UNUSED(buft);
13175}
13176
13177static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
13178 return vk_instance.devices[0]->suballocation_block_size;
13179
13180 UNUSED(buft);
13181}
13182
13183// Should be changed to return device-specific host buffer type
13184// but that probably requires changes in llama.cpp
13185ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
13186 static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = {
13187 /* .iface = */ {
13188 /* .get_name = */ ggml_backend_vk_host_buffer_type_name,
13189 /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer,
13190 /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment,
13191 /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size,
13192 /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
13193 /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
13194 },
13195 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
13196 /* .context = */ nullptr,
13197 };
13198
13199 // Make sure device 0 is initialized
13200 ggml_vk_instance_init();
13201 ggml_vk_get_device(0);
13202
13203 return &ggml_backend_vk_buffer_type_host;
13204}
13205
13206
13207// backend
13208
13209static const char * ggml_backend_vk_name(ggml_backend_t backend) {
13210 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13211
13212 return ctx->name.c_str();
13213}
13214
13215static void ggml_backend_vk_free(ggml_backend_t backend) {
13216 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13217 VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")");
13218
13219 ggml_vk_cleanup(ctx);
13220
13221 delete ctx;
13222 delete backend;
13223}
13224
13225static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) {
13226 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13227
13228 return &ctx->device->buffer_type;
13229}
13230
13231static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
13232 VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
13233 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13234 GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13235
13236 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13237
13238 vk_context compute_ctx;
13239
13240 if (ctx->compute_ctx.expired()) {
13241 // Initialize new transfer context
13242 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13243 ctx->compute_ctx = compute_ctx;
13244 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13245 } else {
13246 compute_ctx = ctx->compute_ctx.lock();
13247 }
13248
13249 vk_buffer buf = buf_ctx->dev_buffer;
13250
13251 auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13252
13253 bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size);
13254
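    // If the async write path could not be used, fall back to a copy through the
    // synchronous staging buffer (mirrors the read fallback in ggml_backend_vk_get_tensor_async).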
13255 if (!ret) {
13256 ggml_vk_ensure_sync_staging_buffer(ctx, size);
13257 ggml_vk_sync_buffers(nullptr, compute_ctx);
13258
13259 vk::BufferCopy buffer_cpy;
13260 buffer_cpy.srcOffset = 0;
13261 buffer_cpy.dstOffset = dst_offset;
13262 buffer_cpy.size = size;
13263
13264 compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
13265 deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys);
13266 ggml_vk_synchronize(ctx);
13267 }
13268}
13269
13270static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
13271 VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
13272 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13273 GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13274
13275 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13276
13277 vk_context compute_ctx;
13278
13279 if (ctx->compute_ctx.expired()) {
13280 // Initialize new transfer context
13281 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13282 ctx->compute_ctx = compute_ctx;
13283 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13284 } else {
13285 compute_ctx = ctx->compute_ctx.lock();
13286 }
13287
13288 vk_buffer buf = buf_ctx->dev_buffer;
13289
13290 auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13291 bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
13292
13293 // If that failed, copy synchronously through a staging buffer
13294 if (!ret) {
13295 ggml_vk_ensure_sync_staging_buffer(ctx, size);
13296 ggml_vk_sync_buffers(nullptr, compute_ctx);
13297
13298 vk::BufferCopy buffer_cpy;
13299 buffer_cpy.srcOffset = src_offset;
13300 buffer_cpy.dstOffset = 0;
13301 buffer_cpy.size = size;
13302
13303 compute_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
13304 deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
13305 ggml_vk_synchronize(ctx);
13306 }
13307}
13308
13309static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
13310 VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
13311 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13312 if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
13313 ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
13314 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
13315
13316 vk_context compute_ctx;
13317
13318 if (ctx->compute_ctx.expired()) {
13319 // Initialize new transfer context
13320 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13321 ctx->compute_ctx = compute_ctx;
13322 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13323 } else {
13324 compute_ctx = ctx->compute_ctx.lock();
13325 }
13326
13327 vk_buffer src_buf = src_buf_ctx->dev_buffer;
13328 vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
13329
13330 ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
13331 return true;
13332 }
13333
13334 return false;
13335}
13336
13337static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
13338 VK_LOG_DEBUG("ggml_vk_synchronize()");
13339
13340 bool do_transfer = !ctx->compute_ctx.expired();
13341
13342 vk_context compute_ctx;
13343 if (do_transfer) {
13344 compute_ctx = ctx->compute_ctx.lock();
13345
13346 ggml_vk_ctx_end(compute_ctx);
13347
13348 for (auto& cpy : compute_ctx->in_memcpys) {
13349 memcpy(cpy.dst, cpy.src, cpy.n);
13350 }
13351
13352 ggml_vk_submit(compute_ctx, {});
13353 ctx->submit_pending = true;
13354 }
13355
13356 if (ctx->submit_pending) {
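        // Submit an empty batch that signals ctx->fence, then wait on it so that all
        // previously submitted work on the compute queue has completed.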
13357 {
13358 std::lock_guard<std::mutex> guard(queue_mutex);
13359 ctx->device->compute_queue.queue.submit({}, ctx->fence);
13360 }
13361 ggml_vk_wait_for_fence(ctx);
13362 ctx->submit_pending = false;
13363 }
13364
13365 if (do_transfer) {
13366 for (auto& cpy : compute_ctx->out_memcpys) {
13367 memcpy(cpy.dst, cpy.src, cpy.n);
13368 }
13369 ctx->compute_ctx.reset();
13370 }
13371}
13372
13373static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
13374 VK_LOG_DEBUG("ggml_backend_vk_synchronize()");
13375 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13376
13377 ggml_vk_synchronize(ctx);
13378
13379 ggml_vk_graph_cleanup(ctx);
13380}
13381
13382static bool ggml_vk_is_empty(ggml_tensor * node) {
13383 return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
13384}
13385
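// Backend-specific constraints on top of ggml_can_fuse for the op sequences
// the Vulkan shaders can fuse (e.g. RMS_NORM+MUL, MUL_MAT+ADD[+ADD],
// MUL_MAT_ID+ADD_ID[+MUL], MUL_MAT_ID+MUL).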
13386static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
13387 if (!ggml_can_fuse(cgraph, node_idx, ops)) {
13388 return false;
13389 }
13390
13391 if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
13392 // additional constraints specific to this fusion
13393 const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
13394 const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
13395
13396 GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
13397 GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
13398 // rms_norm only supports f32
13399 if (mul->src[0]->type != GGML_TYPE_F32 ||
13400 mul->src[1]->type != GGML_TYPE_F32 ||
13401 mul->type != GGML_TYPE_F32) {
13402 return false;
13403 }
13404 // if rms_norm is the B operand, then we don't handle broadcast
13405 if (rms_norm == mul->src[1] &&
13406 !ggml_are_same_shape(mul->src[0], rms_norm)) {
13407 return false;
13408 }
13409 // rms_norm shader assumes contiguous rows
13410 if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
13411 return false;
13412 }
13413 }
13414 auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
13415 const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
13416
13417 // mat-vec only
13418 if (ggml_nrows(mul) != 1) {
13419 return false;
13420 }
13421 // shaders assume the types match
13422 if (mul->type != bias->type) {
13423 return false;
13424 }
13425 // shaders reuse the D shape for bias
13426 if (!ggml_are_same_shape(mul, bias) ||
13427 !ggml_are_same_stride(mul, bias)) {
13428 return false;
13429 }
13430 // unaligned bias isn't handled
13431 if (get_misalign_bytes(ctx, bias) != 0) {
13432 return false;
13433 }
13434 return true;
13435 };
13436
13437 if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
13438 // additional constraints specific to this fusion
13439 const ggml_tensor *mul = cgraph->nodes[node_idx];
13440 const ggml_tensor *add = cgraph->nodes[node_idx + 1];
13441
13442 if (!mm_add_ok(mul, add)) {
13443 return false;
13444 }
13445 if (ops.size() == 3) {
13446 if (ops.begin()[2] != GGML_OP_ADD) {
13447 return false;
13448 }
13449 if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
13450 return false;
13451 }
13452 }
13453 }
13454
13455 auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
13456 const ggml_tensor *scale = mul->src[1];
13457
13458 if (mmid != mul->src[0]) {
13459 return false;
13460 }
13461 // mat-vec only
13462 if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
13463 return false;
13464 }
13465 // shaders assume the types match
13466 if (mmid->type != scale->type) {
13467 return false;
13468 }
13469 // shaders assume the bias is contiguous
13470 if (!ggml_is_contiguous(scale)) {
13471 return false;
13472 }
13473 // unaligned bias isn't handled
13474 if (get_misalign_bytes(ctx, scale) != 0) {
13475 return false;
13476 }
13477 // shader only indexes by expert index
13478 if (scale->ne[0] != 1 ||
13479 scale->ne[1] != mul->ne[1] ||
13480 scale->ne[2] != 1 ||
13481 scale->ne[3] != 1) {
13482 return false;
13483 }
13484 return true;
13485 };
13486
13487 if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
13488 // additional constraints specific to this fusion
13489 const ggml_tensor *mul = cgraph->nodes[node_idx];
13490 const ggml_tensor *add = cgraph->nodes[node_idx + 1];
13491 const ggml_tensor *bias = add->src[1];
13492
13493 if (mul != add->src[0]) {
13494 return false;
13495 }
13496 // mat-vec only
13497 if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
13498 return false;
13499 }
13500 // shaders assume the types match
13501 if (mul->type != bias->type) {
13502 return false;
13503 }
13504 // shaders assume the bias is contiguous
13505 if (!ggml_is_contiguous(bias)) {
13506 return false;
13507 }
13508 // the ID tensor must be the same for mul_mat_id and add_id
13509 if (mul->src[2] != add->src[2]) {
13510 return false;
13511 }
13512 // unaligned bias isn't handled
13513 if (get_misalign_bytes(ctx, bias) != 0) {
13514 return false;
13515 }
13516
13517 if (ops.size() == 3) {
13518 if (ops.begin()[2] != GGML_OP_MUL) {
13519 return false;
13520 }
13521 const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
13522 return mmid_mul_ok(add, mul);
13523 }
13524 }
13525
13526 if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
13527 // additional constraints specific to this fusion
13528 const ggml_tensor *mmid = cgraph->nodes[node_idx];
13529 const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
13530
13531 if (!mmid_mul_ok(mmid, mul)) {
13532 return false;
13533 }
13534 }
13535
13536 return true;
13537}
13538
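// Check whether the top-k MoE routing subgraph starting at node_idx can be
// handled by the fused shader for the given mode. The node offsets below pick
// out the softmax/sigmoid, weights, get_rows and argsort nodes of each
// expected pattern.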
13539static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
13540 int node_idx, topk_moe_mode mode) {
13541
13542 const ggml_tensor * softmax;
13543 const ggml_tensor * weights;
13544 const ggml_tensor * get_rows;
13545 const ggml_tensor * argsort;
13546
13547 switch (mode) {
13548 case TOPK_MOE_EARLY_SOFTMAX_NORM:
13549 softmax = cgraph->nodes[node_idx + 0];
13550 weights = cgraph->nodes[node_idx + 9];
13551 get_rows = cgraph->nodes[node_idx + 4];
13552 argsort = cgraph->nodes[node_idx + 2];
13553 break;
13554 case TOPK_MOE_SIGMOID_NORM_BIAS:
13555 softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
13556 weights = cgraph->nodes[node_idx + 10];
13557 get_rows = cgraph->nodes[node_idx + 5];
13558 argsort = cgraph->nodes[node_idx + 3];
13559 if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
13560 return false;
13561 }
13562 // bias is expected to be 1D
13563 if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
13564 !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
13565 return false;
13566 }
13567 // sigmoid fusion seems to generate infinities on moltenvk
13568 if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
13569 return false;
13570 }
13571 break;
13572 case TOPK_MOE_EARLY_SOFTMAX:
13573 softmax = cgraph->nodes[node_idx + 0];
13574 weights = cgraph->nodes[node_idx + 4];
13575 get_rows = cgraph->nodes[node_idx + 4];
13576 argsort = cgraph->nodes[node_idx + 2];
13577 break;
13578 case TOPK_MOE_LATE_SOFTMAX:
13579 softmax = cgraph->nodes[node_idx + 4];
13580 weights = cgraph->nodes[node_idx + 5];
13581 get_rows = cgraph->nodes[node_idx + 2];
13582 argsort = cgraph->nodes[node_idx + 0];
13583 break;
13584 default:
13585 return false;
13586 }
13587
13588 ggml_tensor * probs = get_rows->src[0];
13589 if (probs->op != GGML_OP_RESHAPE) {
13590 return false;
13591 }
13592 probs = probs->src[0];
13593 ggml_tensor * selection_probs = argsort->src[0];
13594
13595 if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
13596 return false;
13597 }
13598
13599 if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
13600 return false;
13601 }
13602
13603 if (softmax->op == GGML_OP_SOFT_MAX) {
13604 const float * op_params = (const float *)softmax->op_params;
13605
13606 float scale = op_params[0];
13607 float max_bias = op_params[1];
13608
13609 if (scale != 1.0f || max_bias != 0.0f) {
13610 return false;
13611 }
13612
13613 // don't fuse when masks or sinks are present
13614 if (softmax->src[1] || softmax->src[2]) {
13615 return false;
13616 }
13617 }
13618
13619 const int n_expert = softmax->ne[0];
13620 if (n_expert > (1 << (num_topk_moe_pipelines-1))) {
13621 return false;
13622 }
13623
13624 if (!ctx->device->subgroup_arithmetic ||
13625 !ctx->device->subgroup_shuffle ||
13626 !ctx->device->subgroup_require_full_support ||
13627 ctx->device->disable_fusion) {
13628 return false;
13629 }
13630
13631 return true;
13632}
13633
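// Check the constraints for fusing ROPE + VIEW + SET_ROWS starting at
// node_idx into a single dispatch.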
13634static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
13635 int node_idx) {
13636 GGML_UNUSED(ctx);
13637 const ggml_tensor *rope = cgraph->nodes[node_idx + 0];
13638 const ggml_tensor *view = cgraph->nodes[node_idx + 1];
13639 const ggml_tensor *set_rows = cgraph->nodes[node_idx + 2];
13640
13641 // ne3 not tested
13642 if (rope->src[0]->ne[3] != 1) {
13643 return false;
13644 }
13645
13646 if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
13647 return false;
13648 }
13649
13650 if (set_rows->src[1]->type != GGML_TYPE_I64) {
13651 return false;
13652 }
13653
13654 // The view should flatten two dims of rope into one dim
13655 if (!ggml_is_contiguous(view) ||
13656 view->ne[0] != rope->ne[0] * rope->ne[1]) {
13657 return false;
13658 }
13659
13660 // Only norm/neox/mrope shaders have the fusion code
13661 const int mode = ((const int32_t *) rope->op_params)[2];
13662 if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_MROPE) {
13663 return false;
13664 }
13665
13666 return true;
13667}
13668
13669// Check whether the tensors overlap in memory but are not equal.
13670 // Fusions can potentially overwrite src tensors in ways that are not prevented
13671// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
13672// to overlap if they are exactly equal.
13673// XXX TODO this check is probably missing from several fusion optimizations.
13674static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
13675 ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
13676 vk_buffer a_buf = a_buf_ctx->dev_buffer;
13677 ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
13678 vk_buffer b_buf = b_buf_ctx->dev_buffer;
13679 if (a_buf == b_buf) {
13680 auto a_base = vk_tensor_offset(a) + a->view_offs;
13681 auto a_size = ggml_nbytes(a);
13682 auto b_base = vk_tensor_offset(b) + b->view_offs;
13683 auto b_size = ggml_nbytes(b);
13684
13685 if (a_base == b_base && a_size == b_size) {
13686 return false;
13687 }
13688
13689 if ((b_base <= a_base && a_base < b_base + b_size) ||
13690 (a_base <= b_base && b_base < a_base + a_size)) {
13691 return true;
13692 }
13693 }
13694 return false;
13695}
13696
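// Check the constraints for fusing RMS_NORM + MUL + ROPE starting at
// node_idx. The fused shader passes the mul result to rope through shared
// memory, which limits the supported row size (see the ne[0] check below).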
13697static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
13698 int node_idx) {
13699 GGML_UNUSED(ctx);
13700 const ggml_tensor *rms = cgraph->nodes[node_idx + 0];
13701 const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
13702 const ggml_tensor *rope = cgraph->nodes[node_idx + 2];
13703
13704 const int mode = ((const int32_t *) rope->op_params)[2];
13705
13706 // noncontig tensors aren't tested, and don't seem common in practice
13707 if (!ggml_is_contiguous(rms) ||
13708 !ggml_is_contiguous(mul) ||
13709 !ggml_is_contiguous(rope)) {
13710 return false;
13711 }
13712
13713 // only norm/neox are handled in the shader
13714 if (mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_NORMAL) {
13715 return false;
13716 }
13717
13718 // shared memory size for passing data from mul->rope
13719 if (mul->ne[0] > 1024) {
13720 return false;
13721 }
13722
13723 // must not overwrite srcs in a way that's not elementwise
13724 ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
13725 if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
13726 ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
13727 return false;
13728 }
13729
13730 // conditions for pipeline creation
13731 if (!(ctx->device->float_controls_rte_fp16 &&
13732 sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
13733 return false;
13734 }
13735
13736 return true;
13737}
13738
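// Count how many consecutive GGML_OP_ADD nodes starting at node_idx can be
// folded into a single multi-add dispatch. Returns the total number of adds
// fused, or zero if no fusion applies (a lone add is not considered fused).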
13739static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
13740
13741 const ggml_tensor *first_node = cgraph->nodes[node_idx];
13742 if (first_node->op != GGML_OP_ADD) {
13743 return 0;
13744 }
13745
13746 if (!ctx->device->multi_add) {
13747 return 0;
13748 }
13749
13750 int32_t num_adds = 1;
13751 while (node_idx + num_adds < cgraph->n_nodes &&
13752 cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD &&
13753 num_adds < MAX_FUSED_ADDS) {
13754 num_adds++;
13755 }
13756
13757 // The shader currently requires same shapes (but different strides are allowed),
13758 // everything f32, and no misalignment
13759 for (int32_t i = 0; i < num_adds; ++i) {
13760 const ggml_tensor *next_node = cgraph->nodes[node_idx + i];
13761 if (!ggml_are_same_shape(first_node, next_node->src[0]) ||
13762 !ggml_are_same_shape(first_node, next_node->src[1]) ||
13763 next_node->type != GGML_TYPE_F32 ||
13764 next_node->src[0]->type != GGML_TYPE_F32 ||
13765 next_node->src[1]->type != GGML_TYPE_F32 ||
13766 get_misalign_bytes(ctx, next_node) ||
13767 get_misalign_bytes(ctx, next_node->src[0]) ||
13768 get_misalign_bytes(ctx, next_node->src[1])) {
13769 num_adds = i;
13770 }
13771 }
13772
13773 // Verify we can fuse these
13774 ggml_op adds[MAX_FUSED_ADDS];
13775 for (int32_t i = 0; i < num_adds; ++i) {
13776 adds[i] = GGML_OP_ADD;
13777 }
13778
13779 // decrease num_adds if they can't all be fused
13780 while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) {
13781 num_adds--;
13782 }
13783
13784 // a single add is not "fused", so just return zero
13785 if (num_adds == 1) {
13786 return 0;
13787 }
13788 return num_adds;
13789}
13790
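// Main graph execution entry point: walks the cgraph, applies fusions where
// possible, batches nodes into command buffer submissions, and optionally
// records timestamps for the perf logger.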
13791static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
13792 VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
13793 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13794
13795 if (vk_instance.debug_utils_support) {
13796 vk::DebugUtilsLabelEXT dul = {};
13797 dul.pLabelName = "ggml_backend_vk_graph_compute";
13798 dul.color = std::array<float,4>{1.0f, 1.0f, 1.0f, 1.0f};
13799 vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast<VkDebugUtilsLabelEXT*>(&dul));
13800 }
13801
13802 ctx->prealloc_size_add_rms_partials_offset = 0;
13803 ctx->do_add_rms_partials = false;
13804 ctx->do_add_rms_partials_offset_calculation = false;
13805
13806 int last_node = cgraph->n_nodes - 1;
13807
13808    // If the last op in the cgraph doesn't run on the GPU backend, the command buffer doesn't get closed properly
13809 while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
13810 last_node -= 1;
13811 }
13812
13813 // Reserve tensor context space for all nodes
13814 ctx->tensor_ctxs.resize(cgraph->n_nodes);
13815
13816 bool first_node_in_batch = true; // true if next node will be first node in a batch
13817 int submit_node_idx = 0; // index to first node in a batch
13818
13819 vk_context compute_ctx;
13820 if (vk_perf_logger_enabled) {
13821 // allocate/resize the query pool
13822 if (ctx->num_queries < cgraph->n_nodes + 1) {
13823 if (ctx->query_pool) {
13824 ctx->device->device.destroyQueryPool(ctx->query_pool);
13825 }
13826 vk::QueryPoolCreateInfo query_create_info;
13827 query_create_info.queryType = vk::QueryType::eTimestamp;
13828 query_create_info.queryCount = cgraph->n_nodes + 100;
13829 ctx->query_pool = ctx->device->device.createQueryPool(query_create_info);
13830 ctx->num_queries = query_create_info.queryCount;
13831 ctx->query_fusion_names.resize(ctx->num_queries);
13832 ctx->query_fusion_node_count.resize(ctx->num_queries);
13833 ctx->query_nodes.resize(ctx->num_queries);
13834 ctx->query_node_idx.resize(ctx->num_queries);
13835 }
13836
13837 ctx->device->device.resetQueryPool(ctx->query_pool, 0, cgraph->n_nodes+1);
13838 std::fill(ctx->query_fusion_names.begin(), ctx->query_fusion_names.end(), nullptr);
13839 std::fill(ctx->query_fusion_node_count.begin(), ctx->query_fusion_node_count.end(), 0);
13840 std::fill(ctx->query_nodes.begin(), ctx->query_nodes.end(), nullptr);
13841 std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
13842
13843 GGML_ASSERT(ctx->compute_ctx.expired());
13844 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13845 ctx->compute_ctx = compute_ctx;
13846 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13847 ctx->query_idx = 0;
13848 compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
13849 }
13850
13851 ctx->prealloc_y_last_pipeline_used = nullptr;
13852 ctx->prealloc_y_last_tensor_used = nullptr;
13853
13854 if (ctx->prealloc_size_add_rms_partials) {
13855 ggml_vk_preallocate_buffers(ctx, nullptr);
13856 if (ctx->compute_ctx.expired()) {
13857 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13858 ctx->compute_ctx = compute_ctx;
13859 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13860 } else {
13861 compute_ctx = ctx->compute_ctx.lock();
13862 }
13863 // initialize partial sums to zero.
13864 ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
13865 ggml_vk_sync_buffers(ctx, compute_ctx);
13866 }
13867
13868 // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
13869 // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
13870 // (and scaled down based on model size, so smaller models submit earlier).
13871 // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
13872 int nodes_per_submit = 100;
13873 int submitted_nodes = 0;
13874 int submit_count = 0;
13875 uint64_t mul_mat_bytes = 0;
13876 uint64_t total_mul_mat_bytes = 0;
13877 uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), ctx->last_total_mul_mat_bytes / 40u);
13878 for (int i = 0; i < cgraph->n_nodes; i++) {
13879 if (first_node_in_batch) {
13880 submit_node_idx = i;
13881 }
13882
13883 if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
13884 auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
13885 mul_mat_bytes += bytes;
13886 total_mul_mat_bytes += bytes;
13887 }
13888
13889 ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
13890 ctx->fused_topk_moe_scale = false;
13891 const char *fusion_string {};
13892 if (!ctx->device->disable_fusion) {
13893 uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
13894 if (num_adds) {
13895 ctx->num_additional_fused_ops = num_adds - 1;
13896 fusion_string = "MULTI_ADD";
13897 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
13898 ctx->num_additional_fused_ops = 2;
13899 fusion_string = "MUL_MAT_ADD_ADD";
13900 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
13901 ctx->num_additional_fused_ops = 1;
13902 fusion_string = "MUL_MAT_ADD";
13903 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
13904 ctx->num_additional_fused_ops = 2;
13905 fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
13906 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
13907 ctx->num_additional_fused_ops = 1;
13908 fusion_string = "MUL_MAT_ID_ADD_ID";
13909 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
13910 ctx->num_additional_fused_ops = 1;
13911 fusion_string = "MUL_MAT_ID_MUL";
13912 } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
13913 ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
13914 ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
13915 ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
13916 ctx->num_additional_fused_ops = 4;
13917 fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
13918            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE }) &&
13919 ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
13920 ctx->num_additional_fused_ops = 2;
13921 fusion_string = "RMS_NORM_MUL_ROPE";
13922 } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
13923 ctx->num_additional_fused_ops = 1;
13924 fusion_string = "RMS_NORM_MUL";
13925 } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
13926 ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
13927 ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
13928 ctx->num_additional_fused_ops = 2;
13929 fusion_string = "ROPE_VIEW_SET_ROWS";
13930 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
13931 ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
13932 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
13933 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
13934 // view of argsort writes to memory
13935 ctx->fused_ops_write_mask |= 1 << 3;
13936 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
13937 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
13938 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
13939 ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
13940 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
13941 ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
13942 // view of argsort writes to memory
13943 ctx->fused_ops_write_mask |= 1 << 4;
13944 ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
13945 fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
13946 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
13947 ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
13948 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
13949 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
13950 // view of argsort writes to memory
13951 ctx->fused_ops_write_mask |= 1 << 3;
13952 ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
13953 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
13954 } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
13955 ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
13956 ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
13957 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
13958 // view of argsort writes to memory
13959 ctx->fused_ops_write_mask |= 1 << 1;
13960 ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
13961 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
13962 }
13963 if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
13964 // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
13965 if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
13966 ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
13967 ctx->fused_topk_moe_scale = true;
13968 ctx->num_additional_fused_ops++;
13969 }
13970 }
13971 }
13972 ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
13973
13974 // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
13975 bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
13976 bool submit = (submitted_nodes >= nodes_per_submit) ||
13977 (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) ||
13978 (i + ctx->num_additional_fused_ops >= last_node) ||
13979 (almost_ready && !ctx->almost_ready_fence_pending);
13980
13981 bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
13982
13983 if (vk_perf_logger_enabled && enqueued) {
13984 if (ctx->compute_ctx.expired()) {
13985 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13986 ctx->compute_ctx = compute_ctx;
13987 ggml_vk_ctx_begin(ctx->device, compute_ctx);
13988 } else {
13989 compute_ctx = ctx->compute_ctx.lock();
13990 }
13991 if (!vk_perf_logger_concurrent) {
13992 // track a single node/fusion for the current query
13993 ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
13994 ctx->query_fusion_names[ctx->query_idx] = fusion_string;
13995 compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
13996 } else {
13997 // track a fusion string and number of fused ops for the current node_idx
13998 ctx->query_fusion_names[i] = fusion_string;
13999 ctx->query_fusion_node_count[i] = ctx->num_additional_fused_ops;
14000 }
14001 }
14002
14003 if (enqueued) {
14004 ++submitted_nodes;
14005
14006#ifndef GGML_VULKAN_CHECK_RESULTS
14007 if (first_node_in_batch) {
14008 first_node_in_batch = false;
14009 }
14010#endif
14011 }
14012
14013 if (submit && enqueued) {
14014 first_node_in_batch = true;
14015 submitted_nodes = 0;
14016 mul_mat_bytes = 0;
14017 if (submit_count < 3) {
14018 mul_mat_bytes_per_submit *= 2;
14019 }
14020 submit_count++;
14021 }
14022 i += ctx->num_additional_fused_ops;
14023 ctx->num_additional_fused_ops = 0;
14024 ctx->fused_ops_write_mask = 0;
14025 }
14026
14027 ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
14028
14029 if (vk_perf_logger_enabled) {
14030 // End the command buffer and submit/wait
14031 GGML_ASSERT(!ctx->compute_ctx.expired());
14032 compute_ctx = ctx->compute_ctx.lock();
14033 ggml_vk_ctx_end(compute_ctx);
14034
14035 ggml_vk_submit(compute_ctx, ctx->device->fence);
14036 VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
14037 ctx->device->device.resetFences({ ctx->device->fence });
14038 ctx->compute_ctx.reset();
14039
14040 // Get the results and pass them to the logger
14041 std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
14042 VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->query_pool, 0, ctx->query_idx, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results");
14043 if (!vk_perf_logger_concurrent) {
14044 // Log each op separately
14045 for (int i = 1; i < ctx->query_idx; i++) {
14046 auto node = ctx->query_nodes[i];
14047 auto name = ctx->query_fusion_names[i];
14048 ctx->perf_logger->log_timing(node, name, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
14049 }
14050 } else {
14051 // Log each group of nodes
14052 int prev_node_idx = 0;
14053 for (int i = 1; i < ctx->query_idx; i++) {
14054 auto cur_node_idx = ctx->query_node_idx[i];
14055 std::vector<ggml_tensor *> nodes;
14056 std::vector<const char *> names;
14057 for (int node_idx = prev_node_idx; node_idx < cur_node_idx; ++node_idx) {
14058 if (ggml_op_is_empty(cgraph->nodes[node_idx]->op)) {
14059 continue;
14060 }
14061 nodes.push_back(cgraph->nodes[node_idx]);
14062 names.push_back(ctx->query_fusion_names[node_idx]);
14063 node_idx += ctx->query_fusion_node_count[node_idx];
14064 }
14065 prev_node_idx = cur_node_idx;
14066 ctx->perf_logger->log_timing(nodes, names, uint64_t((timestamps[i] - timestamps[i-1]) * ctx->device->properties.limits.timestampPeriod));
14067 }
14068 }
14069 ctx->perf_logger->print_timings();
14070 }
14071
14072 if (!ctx->device->support_async) {
14073 ggml_vk_synchronize(ctx);
14074 }
14075
14076 return GGML_STATUS_SUCCESS;
14077
14078 UNUSED(backend);
14079}
14080
14081// Sort the graph for improved parallelism.
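// Greedily groups independent nodes so they can execute back to back, while
// keeping known fusion patterns (top-k MoE, RMS_NORM+MUL(+ROPE), MUL_MAT+ADD,
// ROPE+VIEW+SET_ROWS, ...) consecutive so they can still be fused.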
14082static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph)
14083{
14084 VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)");
14085 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14086
14087 if (ctx->device->disable_graph_optimize) {
14088 return;
14089 }
14090
14091 auto const &is_empty = [](ggml_tensor * node) -> bool {
14092 return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
14093 };
14094
14095 auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
14096 for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
14097 if (dst->src[s] == src) {
14098 return true;
14099 }
14100 }
14101 // implicit dependency if they view the same tensor
14102 const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
14103 const ggml_tensor *src2 = src->view_src ? src->view_src : src;
14104 if (dst2 == src2) {
14105 return true;
14106 }
14107 return false;
14108 };
14109
14110 std::vector<ggml_tensor *> new_order;
14111 std::vector<bool> used(graph->n_nodes, false);
14112 std::set<ggml_tensor *> used_node_set;
14113
14114 int first_unused = 0;
14115 while (first_unused < graph->n_nodes) {
14116 std::vector<int> current_set;
14117
14118 // Check for fusion patterns and avoid reordering them
14119 auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
14120 if (start + (int)pattern.size() <= graph->n_nodes) {
14121 bool is_pattern = true;
14122 for (size_t j = 0; j < pattern.size(); ++j) {
14123 if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
14124 is_pattern = false;
14125 }
14126 }
14127 return is_pattern;
14128 }
14129 return false;
14130 };
14131
14132 auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
14133 if (match_pattern(pattern, first_unused)) {
14134 for (size_t j = 0; j < pattern.size(); ++j) {
14135 new_order.push_back(graph->nodes[first_unused + j]);
14136 used_node_set.insert(graph->nodes[first_unused + j]);
14137 used[first_unused + j] = true;
14138 }
14139 while (first_unused < graph->n_nodes && used[first_unused]) {
14140 first_unused++;
14141 }
14142 return true;
14143 }
14144 return false;
14145 };
14146
14147 if (keep_pattern(topk_moe_early_softmax_norm)) {
14148 continue;
14149 }
14150 if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
14151 continue;
14152 }
14153 if (keep_pattern(topk_moe_early_softmax)) {
14154 continue;
14155 }
14156 if (keep_pattern(topk_moe_late_softmax)) {
14157 continue;
14158 }
14159
14160 // First, grab the next unused node.
14161 current_set.push_back(first_unused);
14162
14163 // Loop through the next N nodes. Grab any that don't depend on other nodes that
14164 // haven't already been run. Nodes that have already been run have used[i] set
14165 // to true. Allow nodes that depend on the previous node if it's a fusion pattern
14166 // that we support (e.g. RMS_NORM + MUL).
14167    // This first pass only grabs "real" (non-view) nodes. Second pass grabs view nodes.
14168 // The goal is to not interleave real and view nodes in a way that breaks fusion.
14169 const int NUM_TO_CHECK = 20;
14170 for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
14171 if (used[j]) {
14172 continue;
14173 }
14174 if (is_empty(graph->nodes[j])) {
14175 continue;
14176 }
14177 // Don't pull forward nodes from fusion patterns
14178 if (match_pattern(topk_moe_early_softmax_norm, j) ||
14179 match_pattern(topk_moe_sigmoid_norm_bias, j) ||
14180 match_pattern(topk_moe_early_softmax, j) ||
14181 match_pattern(topk_moe_late_softmax, j)) {
14182 continue;
14183 }
14184 bool ok = true;
14185 for (int c = first_unused; c < j; ++c) {
14186 if (!used[c] &&
14187 is_src_of(graph->nodes[j], graph->nodes[c]) &&
14188 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
14189 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
14190 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
14191 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
14192 !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
14193 ok = false;
14194 break;
14195 }
14196 }
14197 if (ok) {
14198 current_set.push_back(j);
14199
14200 int rope_idx = j;
14201
14202 // When we've found RMS_NORM + MUL, try to find a ROPE that uses it
14203 if (j > 0 &&
14204 graph->nodes[j]->op == GGML_OP_MUL &&
14205 graph->nodes[j-1]->op == GGML_OP_RMS_NORM) {
14206 for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14207 if (graph->nodes[k]->op == GGML_OP_ROPE &&
14208 graph->nodes[k]->src[0] == graph->nodes[j] &&
14209 // Check that other srcs are already valid
14210 graph->nodes[k]->src[1]->op == GGML_OP_NONE &&
14211 (graph->nodes[k]->src[2] == nullptr || graph->nodes[k]->src[2]->op == GGML_OP_NONE)) {
14212 rope_idx = k;
14213 current_set.push_back(rope_idx);
14214 used[rope_idx] = true;
14215 break;
14216 }
14217 }
14218 }
14219 // Look for ROPE + VIEW + SET_ROWS and make them consecutive
14220 if (graph->nodes[rope_idx]->op == GGML_OP_ROPE) {
14221 int view_idx = -1;
14222 int set_rows_idx = -1;
14223 for (int k = rope_idx+1; k < std::min(rope_idx + 10, graph->n_nodes); ++k) {
14224 if (view_idx == -1 &&
14225 graph->nodes[k]->op == GGML_OP_VIEW &&
14226 graph->nodes[k]->src[0] == graph->nodes[rope_idx]) {
14227 view_idx = k;
14228 continue;
14229 }
14230 if (view_idx != -1 &&
14231 set_rows_idx == -1 &&
14232 graph->nodes[k]->op == GGML_OP_SET_ROWS &&
14233 graph->nodes[k]->src[0] == graph->nodes[view_idx]) {
14234 set_rows_idx = k;
14235 break;
14236 }
14237 }
14238 if (set_rows_idx != -1) {
14239 current_set.push_back(view_idx);
14240 current_set.push_back(set_rows_idx);
14241 used[view_idx] = true;
14242 used[set_rows_idx] = true;
14243 }
14244 }
14245 // Look for MUL_MAT_ID + ADD_ID + MUL
14246 if (j > 0 &&
14247 graph->nodes[j]->op == GGML_OP_ADD_ID &&
14248 graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
14249 for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14250 if (graph->nodes[k]->op == GGML_OP_MUL &&
14251 graph->nodes[k]->src[0] == graph->nodes[j] &&
14252 // src1 must either be weights or already processed
14253 (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
14254 current_set.push_back(k);
14255 used[k] = true;
14256 break;
14257 }
14258 }
14259 }
14260 // Look for MUL_MAT + ADD + ADD
14261 if (j > 0 &&
14262 graph->nodes[j]->op == GGML_OP_ADD &&
14263 graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
14264 for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
14265 if (graph->nodes[k]->op == GGML_OP_ADD &&
14266 graph->nodes[k]->src[0] == graph->nodes[j] &&
14267 // src1 must either be weights or already processed
14268 (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
14269 current_set.push_back(k);
14270 used[k] = true;
14271 break;
14272 }
14273 }
14274 }
14275 }
14276 }
14277 // Second pass grabs view nodes.
14278 // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
14279 if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
14280 for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
14281 if (used[j]) {
14282 continue;
14283 }
14284 if (!is_empty(graph->nodes[j])) {
14285 continue;
14286 }
14287 bool ok = true;
14288 for (int c = first_unused; c < j; ++c) {
14289 bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
14290 // skip views whose srcs haven't been processed.
14291 if (!used[c] &&
14292 is_src_of(graph->nodes[j], graph->nodes[c]) &&
14293 !c_in_current_set) {
14294 ok = false;
14295 break;
14296 }
14297 }
14298 if (ok) {
14299 current_set.push_back(j);
14300 }
14301 }
14302 }
14303
14304 // Push the current set into new_order
14305 for (auto c : current_set) {
14306 new_order.push_back(graph->nodes[c]);
14307 used_node_set.insert(graph->nodes[c]);
14308 used[c] = true;
14309 }
14310 while (first_unused < graph->n_nodes && used[first_unused]) {
14311 first_unused++;
14312 }
14313 }
14314 // Replace the graph with the new order.
14315 for (int i = 0; i < graph->n_nodes; ++i) {
14316 graph->nodes[i] = new_order[i];
14317 }
14318}
14319
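// Record a backend event: reset it, set it at the end of the currently
// recorded work, and submit the command buffer signaling the event's fence.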
14320static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
14321 VK_LOG_DEBUG("ggml_backend_vk_event_record(backend=" << backend << ", event=" << event << ")");
14322 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14323 vk_event *vkev = (vk_event *)event->context;
14324
14325 vk_context compute_ctx;
14326
14327 if (ctx->compute_ctx.expired()) {
14328 // Initialize new transfer context
14329 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14330 ctx->compute_ctx = compute_ctx;
14331 ggml_vk_ctx_begin(ctx->device, compute_ctx);
14332 } else {
14333 compute_ctx = ctx->compute_ctx.lock();
14334 }
14335
14336 // the backend interface doesn't have an explicit reset, so reset it here
14337 // before we record the command to set it
14338 ctx->device->device.resetEvent(vkev->event);
14339 ctx->device->device.resetFences({ vkev->fence });
14340
14341 ggml_vk_set_event(compute_ctx, vkev->event);
14342
14343 ggml_vk_ctx_end(compute_ctx);
14344
14345 ggml_vk_submit(compute_ctx, {vkev->fence});
14346 ctx->submit_pending = true;
14347 ctx->compute_ctx.reset();
14348}
14349
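// Make subsequently recorded work in this backend wait on a previously
// recorded event.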
14350static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
14351 VK_LOG_DEBUG("ggml_backend_vk_event_wait(backend=" << backend << ", event=" << event << ")");
14352 ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14353 vk_event *vkev = (vk_event *)event->context;
14354
14355 vk_context compute_ctx;
14356
14357 if (ctx->compute_ctx.expired()) {
14358 // Initialize new transfer context
14359 compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14360 ctx->compute_ctx = compute_ctx;
14361 ggml_vk_ctx_begin(ctx->device, compute_ctx);
14362 } else {
14363 compute_ctx = ctx->compute_ctx.lock();
14364 }
14365
14366 ggml_vk_wait_events(compute_ctx, {vkev->event});
14367 ggml_vk_ctx_end(compute_ctx);
14368 ctx->compute_ctx.reset();
14369}
14370
14371// TODO: enable async and synchronize
14372static ggml_backend_i ggml_backend_vk_interface = {
14373 /* .get_name = */ ggml_backend_vk_name,
14374 /* .free = */ ggml_backend_vk_free,
14375 /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
14376 /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
14377 /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
14378 /* .synchronize = */ ggml_backend_vk_synchronize,
14379 /* .graph_plan_create = */ NULL,
14380 /* .graph_plan_free = */ NULL,
14381 /* .graph_plan_update = */ NULL,
14382 /* .graph_plan_compute = */ NULL,
14383 /* .graph_compute = */ ggml_backend_vk_graph_compute,
14384 /* .event_record = */ ggml_backend_vk_event_record,
14385 /* .event_wait = */ ggml_backend_vk_event_wait,
14386 /* .graph_optimize = */ ggml_vk_graph_optimize,
14387};
14388
14389static ggml_guid_t ggml_backend_vk_guid() {
14390 static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
14391 return &guid;
14392}
14393
14394ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
14395 VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
14396
14397 ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
14398 ggml_vk_init(ctx, dev_num);
14399
14400 ggml_backend_t vk_backend = new ggml_backend {
14401 /* .guid = */ ggml_backend_vk_guid(),
14402 /* .iface = */ ggml_backend_vk_interface,
14403 /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
14404 /* .context = */ ctx,
14405 };
14406
14407 if (!ctx->device->support_async) {
14408 vk_backend->iface.get_tensor_async = nullptr;
14409 }
14410
14411 return vk_backend;
14412}
14413
14414bool ggml_backend_is_vk(ggml_backend_t backend) {
14415 return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
14416}
14417
14418int ggml_backend_vk_get_device_count() {
14419 return ggml_vk_get_device_count();
14420}
14421
14422void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
14423 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
14424 int dev_idx = vk_instance.device_indices[device];
14425 ggml_vk_get_device_description(dev_idx, description, description_size);
14426}
14427
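// Report free/total memory for the device. With VK_EXT_memory_budget, "free"
// is derived from the heap budget minus current usage; otherwise the full
// heap size is reported. Integrated GPUs count all heaps, discrete GPUs only
// device-local ones.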
14428void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
14429 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
14430 GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
14431
14432 vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
14433 vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
14434 vk::PhysicalDeviceMemoryProperties2 memprops = {};
14435 const bool membudget_supported = vk_instance.device_supports_membudget[device];
14436 const bool is_integrated_gpu = vkdev.getProperties().deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
14437
14438 if (membudget_supported) {
14439 memprops.pNext = &budgetprops;
14440 }
14441 vkdev.getMemoryProperties2(&memprops);
14442
14443 *total = 0;
14444 *free = 0;
14445
14446 for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) {
14447 const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i];
14448
14449 if (is_integrated_gpu || (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal)) {
14450 *total += heap.size;
14451
14452 if (membudget_supported && i < budgetprops.heapUsage.size()) {
14453 *free += budgetprops.heapBudget[i] - budgetprops.heapUsage[i];
14454 } else {
14455 *free += heap.size;
14456 }
14457 }
14458 }
14459}
14460
14461static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) {
14462 GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
14463
14464 vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
14465
14466 vk::PhysicalDeviceProperties2 props = {};
14467 device.getProperties2(&props);
14468
14469 return props.properties.deviceType;
14470}
14471
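// Return the PCI bus id ("domain:bus:device.function") via
// VK_EXT_pci_bus_info, or an empty string if the extension is not supported.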
14472static std::string ggml_backend_vk_get_device_pci_id(int device_idx) {
14473 GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size());
14474
14475 vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]];
14476
14477 const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
14478
14479 bool ext_support = false;
14480
14481 for (const auto& properties : ext_props) {
14482 if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) {
14483 ext_support = true;
14484 break;
14485 }
14486 }
14487
14488 if (!ext_support) {
14489 return "";
14490 }
14491
14492 vk::PhysicalDeviceProperties2 props = {};
14493 vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {};
14494
14495 props.pNext = &pci_bus_info;
14496
14497 device.getProperties2(&props);
14498
14499 const uint32_t pci_domain = pci_bus_info.pciDomain;
14500 const uint32_t pci_bus = pci_bus_info.pciBus;
14501 const uint32_t pci_device = pci_bus_info.pciDevice;
14502 const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning
14503
14504 char pci_bus_id[16] = {};
14505 snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function);
14506
14507 return std::string(pci_bus_id);
14508}
14509
14510//////////////////////////
14511
14512struct ggml_backend_vk_device_context {
14513 size_t device;
14514 std::string name;
14515 std::string description;
14516 bool is_integrated_gpu;
14517 std::string pci_bus_id;
14518 int op_offload_min_batch_size;
14519};
14520
14521static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
14522 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14523 return ctx->name.c_str();
14524}
14525
14526static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
14527 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14528 return ctx->description.c_str();
14529}
14530
14531static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
14532 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
14533 ggml_backend_vk_get_device_memory(ctx->device, free, total);
14534}
14535
14536static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
14537 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14538 return ggml_backend_vk_buffer_type(ctx->device);
14539}
14540
14541static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
14542 UNUSED(dev);
14543 return ggml_backend_vk_host_buffer_type();
14544}
14545
14546static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
14547 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14548
14549 return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU;
14550}
14551
14552static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
14553 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14554
14555 props->name = ggml_backend_vk_device_get_name(dev);
14556 props->description = ggml_backend_vk_device_get_description(dev);
14557 props->type = ggml_backend_vk_device_get_type(dev);
14558 props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
14559 ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
14560 props->caps = {
14561 /* .async = */ true,
14562 /* .host_buffer = */ true,
14563 /* .buffer_from_host_ptr = */ false,
14564 /* .events = */ true,
14565 };
14566}
14567
14568static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
14569 UNUSED(params);
14570 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14571 return ggml_backend_vk_init(ctx->device);
14572}
14573
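// Decide whether this device can execute the given op. Besides the per-op
// type/layout checks, any tensor larger than the device's max buffer size (or
// maxStorageBufferRange, unless BDA or 64-bit indexing lifts that limit) is
// rejected.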
14574static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
14575 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14576 const vk_device& device = ggml_vk_get_device(ctx->device);
14577
14578 const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
14579 device->shader_int64 && device->buffer_device_address;
14580
14581 auto const & tensor_size_supported = [&](size_t tensor_size) {
14582 if (tensor_size > device->max_buffer_size) {
14583 return false;
14584 }
14585 // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
14586 // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
14587 if (!uses_bda && !device->shader_64b_indexing) {
14588 if (tensor_size > device->properties.limits.maxStorageBufferRange) {
14589 return false;
14590 }
14591 }
14592 return true;
14593 };
14594 // reject any tensors larger than the max buffer size
14595 for (int i = 0; i < GGML_MAX_SRC; i++) {
14596 if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
14597 return false;
14598 }
14599 }
14600 if (!tensor_size_supported(ggml_nbytes(op))) {
14601 return false;
14602 }
14603
14604 switch (op->op) {
14605 case GGML_OP_UNARY:
14606 switch (ggml_get_unary_op(op)) {
14607 case GGML_UNARY_OP_EXP:
14608 case GGML_UNARY_OP_GELU:
14609 case GGML_UNARY_OP_GELU_ERF:
14610 case GGML_UNARY_OP_GELU_QUICK:
14611 case GGML_UNARY_OP_SILU:
14612 case GGML_UNARY_OP_RELU:
14613 case GGML_UNARY_OP_XIELU:
14614 case GGML_UNARY_OP_NEG:
14615 case GGML_UNARY_OP_TANH:
14616 case GGML_UNARY_OP_SIGMOID:
14617 case GGML_UNARY_OP_HARDSIGMOID:
14618 case GGML_UNARY_OP_HARDSWISH:
14619 case GGML_UNARY_OP_ABS:
14620 case GGML_UNARY_OP_SOFTPLUS:
14621 case GGML_UNARY_OP_STEP:
14622 case GGML_UNARY_OP_ROUND:
14623 case GGML_UNARY_OP_CEIL:
14624 case GGML_UNARY_OP_FLOOR:
14625 case GGML_UNARY_OP_TRUNC:
14626 return ggml_is_contiguous(op->src[0]) &&
14627 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14628 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
14629 (op->src[0]->type == op->type);
14630 default:
14631 return false;
14632 }
14633 case GGML_OP_GLU:
14634 switch (ggml_get_glu_op(op)) {
14635 case GGML_GLU_OP_GEGLU:
14636 case GGML_GLU_OP_REGLU:
14637 case GGML_GLU_OP_SWIGLU:
14638 case GGML_GLU_OP_SWIGLU_OAI:
14639 case GGML_GLU_OP_GEGLU_ERF:
14640 case GGML_GLU_OP_GEGLU_QUICK:
14641 return ggml_is_contiguous(op->src[0]) &&
14642 (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14643 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
14644 (op->src[0]->type == op->type);
14645 default:
14646 return false;
14647 }
14648 case GGML_OP_MUL_MAT:
14649 case GGML_OP_MUL_MAT_ID:
14650 {
14651 ggml_type src0_type = op->src[0]->type;
14652 if (op->op == GGML_OP_MUL_MAT_ID) {
14653 if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
14654                    // If there's not enough shared memory for row_ids and the result tile, fall back to the CPU
14655 return false;
14656 }
14657 }
14658 switch (src0_type) {
14659 case GGML_TYPE_F32:
14660 case GGML_TYPE_F16:
14661 case GGML_TYPE_BF16:
14662 case GGML_TYPE_Q4_0:
14663 case GGML_TYPE_Q4_1:
14664 case GGML_TYPE_Q5_0:
14665 case GGML_TYPE_Q5_1:
14666 case GGML_TYPE_Q8_0:
14667 case GGML_TYPE_Q2_K:
14668 case GGML_TYPE_Q3_K:
14669 case GGML_TYPE_Q4_K:
14670 case GGML_TYPE_Q5_K:
14671 case GGML_TYPE_Q6_K:
14672 case GGML_TYPE_IQ1_S:
14673 case GGML_TYPE_IQ1_M:
14674 case GGML_TYPE_IQ2_XXS:
14675 case GGML_TYPE_IQ2_XS:
14676 case GGML_TYPE_IQ2_S:
14677 case GGML_TYPE_IQ3_XXS:
14678 case GGML_TYPE_IQ3_S:
14679 case GGML_TYPE_IQ4_XS:
14680 case GGML_TYPE_IQ4_NL:
14681 case GGML_TYPE_MXFP4:
14682 break;
14683 default:
14684 return false;
14685 }
14686 struct ggml_tensor * a;
14687 struct ggml_tensor * b;
14688 if (op->op == GGML_OP_MUL_MAT) {
14689 a = op->src[0];
14690 b = op->src[1];
14691 } else {
14692 a = op->src[2];
14693 b = op->src[1];
14694 }
14695 if (a->ne[3] != b->ne[3]) {
14696 return false;
14697 }
14698 if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
14699 !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
14700 return false;
14701 }
14702 if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
14703 // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
14704 // So don't support this combination for now.
14705 return false;
14706 }
14707
14708 return true;
14709 }
14710 case GGML_OP_FLASH_ATTN_EXT:
14711 {
14712 bool coopmat2 = device->coopmat2;
14713 uint32_t HSK = op->src[1]->ne[0];
14714 uint32_t HSV = op->src[2]->ne[0];
14715 if ((HSK % 8) != 0 || (HSV % 8) != 0) {
14716 return false;
14717 }
14718 if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) {
14719 return false;
14720 }
14721 if (op->src[0]->type != GGML_TYPE_F32) {
14722 return false;
14723 }
14724 if (op->type != GGML_TYPE_F32) {
14725 return false;
14726 }
14727 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
14728 return false;
14729 }
14730 // It's straightforward to support different K/V dequant, but would
14731 // significantly increase the number of pipelines
14732 if (op->src[1]->type != op->src[2]->type) {
14733 return false;
14734 }
14735 switch (op->src[1]->type) {
14736 case GGML_TYPE_F16:
14737 case GGML_TYPE_F32:
14738 case GGML_TYPE_Q4_0:
14739 case GGML_TYPE_Q8_0:
14740 // supported in scalar and coopmat2 paths
14741 break;
14742 case GGML_TYPE_Q4_1:
14743 case GGML_TYPE_Q5_0:
14744 case GGML_TYPE_Q5_1:
14745 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
14746 //case GGML_TYPE_Q2_K:
14747 //case GGML_TYPE_Q3_K:
14748 //case GGML_TYPE_Q4_K:
14749 //case GGML_TYPE_Q5_K:
14750 //case GGML_TYPE_Q6_K:
14751 //case GGML_TYPE_IQ1_S:
14752 //case GGML_TYPE_IQ1_M:
14753 //case GGML_TYPE_IQ2_XXS:
14754 //case GGML_TYPE_IQ2_XS:
14755 //case GGML_TYPE_IQ2_S:
14756 //case GGML_TYPE_IQ3_XXS:
14757 //case GGML_TYPE_IQ3_S:
14758 //case GGML_TYPE_IQ4_XS:
14759 case GGML_TYPE_IQ4_NL:
14760 // currently supported only in coopmat2 path
14761 if (!coopmat2) {
14762 return false;
14763 }
14764 break;
14765 default:
14766 return false;
14767 }
14768 if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
14769 // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
14770 return false;
14771 }
14772 return true;
14773 }
14774 case GGML_OP_GET_ROWS:
14775 {
14776 switch (op->src[0]->type) {
14777 case GGML_TYPE_F32:
14778 case GGML_TYPE_F16:
14779 case GGML_TYPE_BF16:
14780 case GGML_TYPE_Q4_0:
14781 case GGML_TYPE_Q4_1:
14782 case GGML_TYPE_Q5_0:
14783 case GGML_TYPE_Q5_1:
14784 case GGML_TYPE_Q8_0:
14785 case GGML_TYPE_Q2_K:
14786 case GGML_TYPE_Q3_K:
14787 case GGML_TYPE_Q4_K:
14788 case GGML_TYPE_Q5_K:
14789 case GGML_TYPE_Q6_K:
14790 case GGML_TYPE_IQ1_S:
14791 case GGML_TYPE_IQ1_M:
14792 case GGML_TYPE_IQ2_XXS:
14793 case GGML_TYPE_IQ2_XS:
14794 case GGML_TYPE_IQ2_S:
14795 case GGML_TYPE_IQ3_XXS:
14796 case GGML_TYPE_IQ3_S:
14797 case GGML_TYPE_IQ4_XS:
14798 case GGML_TYPE_IQ4_NL:
14799 case GGML_TYPE_MXFP4:
14800 case GGML_TYPE_I32:
14801 return true;
14802 default:
14803 return false;
14804 }
14805 }
14806 case GGML_OP_SET_ROWS:
14807 {
14808 switch (op->type) {
14809 case GGML_TYPE_F32:
14810 case GGML_TYPE_F16:
14811 case GGML_TYPE_BF16:
14812 case GGML_TYPE_Q4_0:
14813 case GGML_TYPE_Q4_1:
14814 case GGML_TYPE_Q5_0:
14815 case GGML_TYPE_Q5_1:
14816 case GGML_TYPE_Q8_0:
14817 case GGML_TYPE_IQ4_NL:
14818 return true;
14819 default:
14820 return false;
14821 }
14822 }
14823 case GGML_OP_CONT:
14824 case GGML_OP_CPY:
14825 case GGML_OP_DUP:
14826 {
14827 ggml_type src0_type = op->src[0]->type;
14828 ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
14829
14830 if (src0_type == GGML_TYPE_F32) {
14831 switch (src1_type) {
14832 case GGML_TYPE_F32:
14833 case GGML_TYPE_F16:
14834 case GGML_TYPE_BF16:
14835 case GGML_TYPE_Q4_0:
14836 case GGML_TYPE_Q4_1:
14837 case GGML_TYPE_Q5_0:
14838 case GGML_TYPE_Q5_1:
14839 case GGML_TYPE_Q8_0:
14840 case GGML_TYPE_IQ4_NL:
14841 return true;
14842 default:
14843 break;
14844 }
14845 }
14846 if (src1_type == GGML_TYPE_F32) {
14847 switch (src0_type) {
14848 case GGML_TYPE_F16:
14849 case GGML_TYPE_Q4_0:
14850 case GGML_TYPE_Q4_1:
14851 case GGML_TYPE_Q5_0:
14852 case GGML_TYPE_Q5_1:
14853 case GGML_TYPE_Q8_0:
14854 case GGML_TYPE_IQ4_NL:
14855 return true;
14856 default:
14857 break;
14858 }
14859 }
14860
14861 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
14862 return true;
14863 }
14864
14865 if (
14866 (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) ||
14867 (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32)
14868 ) {
14869 return true;
14870 }
14871
14872 // We can handle copying from a type to the same type if it's
14873 // either not quantized or is quantized and contiguous.
14874 // We use f16 or f32 shaders to do the copy,
14875            // so the type/block size must be a multiple of 2.
14876 if (src0_type == src1_type &&
14877 (!ggml_is_quantized(src0_type) || (ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op))) &&
14878 (ggml_type_size(src0_type) % 2) == 0) {
14879 return true;
14880 }
14881 return false;
14882 }
14883 case GGML_OP_REPEAT:
14884 return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
14885 case GGML_OP_REPEAT_BACK:
14886 return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
14887 case GGML_OP_ROPE:
14888 return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
14889 case GGML_OP_ROPE_BACK:
14890 case GGML_OP_NONE:
14891 case GGML_OP_RESHAPE:
14892 case GGML_OP_VIEW:
14893 case GGML_OP_PERMUTE:
14894 case GGML_OP_TRANSPOSE:
14895 case GGML_OP_RMS_NORM:
14896 return true;
14897 case GGML_OP_NORM:
14898 case GGML_OP_GROUP_NORM:
14899 case GGML_OP_L2_NORM:
14900 return ggml_is_contiguous(op->src[0]);
14901 case GGML_OP_ADD:
14902 case GGML_OP_SUB:
14903 case GGML_OP_MUL:
14904 case GGML_OP_DIV:
14905 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14906 (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
14907 (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
14908 case GGML_OP_ADD_ID:
14909 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
14910 op->type == GGML_TYPE_F32;
14911 case GGML_OP_SILU_BACK:
14912 case GGML_OP_RMS_NORM_BACK:
14913 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14914 case GGML_OP_SQR:
14915 case GGML_OP_SQRT:
14916 case GGML_OP_SIN:
14917 case GGML_OP_COS:
14918 case GGML_OP_CLAMP:
14919 return op->src[0]->type == GGML_TYPE_F32;
14920 case GGML_OP_LEAKY_RELU:
14921 case GGML_OP_OPT_STEP_ADAMW:
14922 case GGML_OP_OPT_STEP_SGD:
14923 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14924 case GGML_OP_LOG:
14925 case GGML_OP_TRI:
14926 case GGML_OP_DIAG:
14927 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14928 op->type == op->src[0]->type;
14929 case GGML_OP_ARGSORT:
14930 {
14931 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
14932 return false;
14933 }
14934 // pipeline_argsort_large_f32 requires vulkan memory model.
14935 if (device->vulkan_memory_model) {
14936 return true;
14937 } else {
14938 return op->ne[0] <= (1 << device->max_workgroup_size_log2);
14939 }
14940 }
14941 case GGML_OP_TOP_K:
14942 {
14943 if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
14944 return false;
14945 }
14946 // We could potentially support larger, using argsort to sort the
14947 // whole thing. Not clear if this is needed.
14948 uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
14949 if (min_pipeline >= num_topk_pipelines ||
14950 !device->pipeline_topk_f32[min_pipeline]) {
14951 return false;
14952 }
14953 }
14954 return true;
14955 case GGML_OP_UPSCALE:
14956 if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) {
14957 if ((op->op_params[0] & 0xFF) != GGML_SCALE_MODE_BILINEAR) {
14958 return false;
14959 }
14960 }
14961 return op->src[0]->type == GGML_TYPE_F32;
14962 case GGML_OP_ACC:
14963 return op->src[0]->type == GGML_TYPE_F32;
14964 case GGML_OP_CONCAT:
14965 return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
14966 case GGML_OP_ADD1:
14967 return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
14968 || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
14969 || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
14970 case GGML_OP_ARANGE:
14971 case GGML_OP_FILL:
14972 return op->type == GGML_TYPE_F32;
14973 case GGML_OP_SCALE:
14974 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14975 case GGML_OP_PAD:
14976 case GGML_OP_ROLL:
14977 return op->src[0]->type == GGML_TYPE_F32;
14978 case GGML_OP_DIAG_MASK_INF:
14979 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
14980 case GGML_OP_SOFT_MAX:
14981 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
14982 && (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
14983 case GGML_OP_SOFT_MAX_BACK:
14984 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
14985 && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
14986 case GGML_OP_SUM:
14987 case GGML_OP_SUM_ROWS:
14988 case GGML_OP_MEAN:
14989 return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
14990 case GGML_OP_CUMSUM:
14991 {
14992 if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
14993 return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
14994 }
14995 return false;
14996 }
14997 case GGML_OP_SOLVE_TRI:
14998 {
14999 if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
15000 return false;
15001 }
15002 const uint32_t N = op->src[0]->ne[0];
15003 const uint32_t K = op->src[1]->ne[0];
15004 // K dimension limited to workgroup size
15005 if (K > 1u << device->max_workgroup_size_log2) {
15006 return false;
15007 }
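// Shared memory must be able to hold at least one (N + K)-float tile for the solver;
// if batch_N comes out as zero below, not even one tile fits and the op is rejected.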
15008 const uint32_t batch_N = device->properties.limits.maxComputeSharedMemorySize / ((N + K) * sizeof(float));
15009
15010 if (batch_N == 0) {
15011 return false;
15012 }
15013 return true;
15014 }
15015 case GGML_OP_ARGMAX:
15016 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
15017 case GGML_OP_COUNT_EQUAL:
15018 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
15019 && ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
15020 case GGML_OP_IM2COL:
15021 return ggml_is_contiguous(op->src[1])
15022 && op->src[1]->type == GGML_TYPE_F32
15023 && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
15024 case GGML_OP_IM2COL_3D:
15025 return op->src[1]->type == GGML_TYPE_F32
15026 && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
15027 case GGML_OP_TIMESTEP_EMBEDDING:
15028 return op->src[0]->type == GGML_TYPE_F32;
15029 case GGML_OP_CONV_2D_DW:
15030 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
15031 && op->src[1]->type == GGML_TYPE_F32;
15032 case GGML_OP_POOL_2D:
15033 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
15034 case GGML_OP_RWKV_WKV6:
15035 case GGML_OP_RWKV_WKV7:
15036 return true; // all inputs are contiguous, see ggml.c
15037 case GGML_OP_SSM_SCAN:
15038 {
15039 for (int i = 0; i < 6; i++) {
15040 if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
15041 return false;
15042 }
15043 }
15044 if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
15045 return false;
15046 }
15047 if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
15048 return false;
15049 }
15050
15051 const uint32_t d_state = op->src[0]->ne[0];
15052 const uint32_t head_dim = op->src[0]->ne[1];
15053
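// Only the Mamba-2 style layout is supported: src[3] (A) is expected to hold a single value
// per head, i.e. to have a row stride of one float. Other SSM_SCAN layouts are rejected here.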
15054 bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
15055 if (!is_mamba2) {
15056 return false;
15057 }
15058
15059 if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
15060 return false;
15061 }
15062
15063 size_t shmem_size = d_state * sizeof(float);
15064
15065 if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) {
15066 return false;
15067 }
15068
15069 if (!device->subgroup_basic) {
15070 return false;
15071 }
15072
15073 return true;
15074 }
15075 case GGML_OP_SSM_CONV:
15076 return op->src[0]->type == GGML_TYPE_F32;
15077 case GGML_OP_CONV_TRANSPOSE_1D:
15078 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
15079 case GGML_OP_CONV_2D:
15080 case GGML_OP_CONV_TRANSPOSE_2D:
15081 {
15082 // Channel-contiguous format is not supported yet.
15083 return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
15084 op->src[1]->type == GGML_TYPE_F32 &&
15085 op->type == GGML_TYPE_F32 &&
15086 ggml_is_contiguous(op->src[0]) &&
15087 ggml_is_contiguous(op->src[1]) &&
15088 ggml_is_contiguous(op));
15089 }
15090 default:
15091 return false;
15092 }
15093
15094 UNUSED(dev);
15095}
15096
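// A Vulkan device can only operate on buffers allocated from its own Vulkan buffer type;
// any other buffer type (including host buffers) is rejected.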
15097static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
15098 if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
15099 return false;
15100 }
15101
15102 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15103 ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
15104
15105 return buft_ctx->device->idx == ctx->device;
15106}
15107
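// Offload heuristic for ops whose data lives in host memory: only offload when the batch
// dimension (ne[1], or ne[2] for MUL_MAT_ID) is large enough to amortize the transfer cost.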
15108static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
15109 ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
15110
15111 return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
15112 (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
15113}
15114
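// Backend events pair a VkEvent with a VkFence. Both start out signaled so that synchronizing
// an event that was never recorded completes immediately; event_synchronize below waits on the
// fence from the host.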
15115static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
15116 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15117 auto device = ggml_vk_get_device(ctx->device);
15118
15119 vk_event *vkev = new vk_event;
15120 if (!vkev) {
15121 return nullptr;
15122 }
15123
15124 // The event/fence is expected to initially be in the signaled state.
15125 vkev->event = device->device.createEvent({});
15126 vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
15127 device->device.setEvent(vkev->event);
15128
15129 return new ggml_backend_event {
15130 /* .device = */ dev,
15131 /* .context = */ vkev,
15132 };
15133}
15134
15135static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
15136 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15137 auto device = ggml_vk_get_device(ctx->device);
15138
15139 vk_event *vkev = (vk_event *)event->context;
15140
15141 device->device.destroyFence(vkev->fence);
15142 device->device.destroyEvent(vkev->event);
15143 delete vkev;
15144 delete event;
15145}
15146
15147static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
15148 VK_LOG_DEBUG("ggml_backend_vk_device_event_synchronize(backend=" << dev << ", event=" << event << ")");
15149 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15150 auto device = ggml_vk_get_device(ctx->device);
15151 vk_event *vkev = (vk_event *)event->context;
15152
15153 VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
15154}
15155
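// Try to import an existing host allocation as a Vulkan buffer (VK_EXT_external_memory_host).
// Returns an empty buffer if the extension is unavailable or the pointer/size do not satisfy
// the device's minimum imported host pointer alignment.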
15156static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
15157 if (!device->external_memory_host) {
15158 return {};
15159 }
15160
15161 uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
15162 if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
15163 return {};
15164 }
15165 if (size & (device->min_imported_host_pointer_alignment - 1)) {
15166 return {};
15167 }
15168
15169 const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
15170
15171 vk_buffer buf {};
15172 try {
15173 buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
15174 } catch (vk::SystemError& e) {
15175 GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
15176 }
15177
15178 return buf;
15179}
15180
15181static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
15182 VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
15183 GGML_UNUSED(max_tensor_size);
15184
15185 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
15186 auto device = ggml_vk_get_device(ctx->device);
15187
15188 vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
15189
15190 if (!buf) {
15191 return {};
15192 }
15193
15194 ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
15195
15196 ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
15197
15198 return ret;
15199}
15200
15201static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
15202 /* .get_name = */ ggml_backend_vk_device_get_name,
15203 /* .get_description = */ ggml_backend_vk_device_get_description,
15204 /* .get_memory = */ ggml_backend_vk_device_get_memory,
15205 /* .get_type = */ ggml_backend_vk_device_get_type,
15206 /* .get_props = */ ggml_backend_vk_device_get_props,
15207 /* .init_backend = */ ggml_backend_vk_device_init,
15208 /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
15209 /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
15210 /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
15211 /* .supports_op = */ ggml_backend_vk_device_supports_op,
15212 /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
15213 /* .offload_op = */ ggml_backend_vk_device_offload_op,
15214 /* .event_new = */ ggml_backend_vk_device_event_new,
15215 /* .event_free = */ ggml_backend_vk_device_event_free,
15216 /* .event_synchronize = */ ggml_backend_vk_device_event_synchronize,
15217};
15218
15219static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
15220 UNUSED(reg);
15221 return GGML_VK_NAME;
15222}
15223
15224static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
15225 UNUSED(reg);
15226 return ggml_backend_vk_get_device_count();
15227}
15228
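// Device objects are created lazily on first lookup, guarded by a mutex, and cached for the
// lifetime of the process. GGML_OP_OFFLOAD_MIN_BATCH (default 32) overrides the minimum batch
// size used by ggml_backend_vk_device_offload_op.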
15229static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
15230 static std::vector<ggml_backend_dev_t> devices;
15231
15232 static bool initialized = false;
15233
15234 {
15235 static std::mutex mutex;
15236 std::lock_guard<std::mutex> lock(mutex);
15237 if (!initialized) {
15238 const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;
15239 for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
15240 ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
15241 char desc[256];
15242 ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
15243 ctx->device = i;
15244 ctx->name = GGML_VK_NAME + std::to_string(i);
15245 ctx->description = desc;
15246 ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu;
15247 ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i);
15248 ctx->op_offload_min_batch_size = min_batch_size;
15249 devices.push_back(new ggml_backend_device {
15250 /* .iface = */ ggml_backend_vk_device_i,
15251 /* .reg = */ reg,
15252 /* .context = */ ctx,
15253 });
15254 }
15255 initialized = true;
15256 }
15257 }
15258
15259 GGML_ASSERT(device < devices.size());
15260 return devices[device];
15261}
15262
15263static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
15264 /* .get_name = */ ggml_backend_vk_reg_get_name,
15265 /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
15266 /* .get_device = */ ggml_backend_vk_reg_get_device,
15267 /* .get_proc_address = */ NULL,
15268};
15269
15270ggml_backend_reg_t ggml_backend_vk_reg() {
15271 static ggml_backend_reg reg = {
15272 /* .api_version = */ GGML_BACKEND_API_VERSION,
15273 /* .iface = */ ggml_backend_vk_reg_i,
15274 /* .context = */ nullptr,
15275 };
15276 try {
15277 ggml_vk_instance_init();
15278 return &reg;
15279 } catch (const vk::SystemError& e) {
15280 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what());
15281 return nullptr;
15282 } catch (const std::exception &e) {
15283 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what());
15284 return nullptr;
15285 } catch (...) {
15286 VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init");
15287 return nullptr;
15288 }
15289}
15290
15291// Extension availability
15292static bool ggml_vk_instance_layer_settings_available() {
15293#ifdef GGML_VULKAN_VALIDATE
15294 // Check if validation layer provides the extension
15295 const std::string layer_name = "VK_LAYER_KHRONOS_validation";
15296 for (const auto& layer : vk::enumerateInstanceLayerProperties()) {
15297 if (layer_name == layer.layerName.data()) {
15298 for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) {
15299 if (strcmp("VK_EXT_layer_settings", ext.extensionName.data()) == 0) {
15300 return true;
15301 }
15302 }
15303 }
15304 }
15305
15306 std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_layer_settings not found." << std::endl;
15307#endif
15308 return false;
15309}
15310static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) {
15311#ifdef __APPLE__
15312 // Check for portability enumeration extension for MoltenVK support
15313 for (const auto& properties : instance_extensions) {
15314 if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) {
15315 return true;
15316 }
15317 }
15318 std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found." << std::endl;
15319#endif
15320 return false;
15321
15322 UNUSED(instance_extensions);
15323}
15324
15325// Extension availability
15326static bool ggml_vk_instance_debug_utils_ext_available(
15327 const std::vector<vk::ExtensionProperties> & instance_extensions) {
15328 // Check for the debug utils extension
15329 for (const auto & properties : instance_extensions) {
15330 if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) {
15331 return true;
15332 }
15333 }
15334
15335 std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl;
15336 return false;
15337
15338 UNUSED(instance_extensions);
15339}
15340
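// Minimum feature check applied during device enumeration: the backend requires 16-bit storage
// buffer access (storageBuffer16BitAccess) and skips devices that lack it.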
15341static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) {
15342 VkPhysicalDeviceFeatures2 device_features2;
15343 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
15344
15345 VkPhysicalDeviceVulkan11Features vk11_features;
15346 vk11_features.pNext = nullptr;
15347 vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES;
15348 device_features2.pNext = &vk11_features;
15349
15350 vkGetPhysicalDeviceFeatures2(vkdev, &device_features2);
15351
15352 return vk11_features.storageBuffer16BitAccess;
15353}
15354
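// Decide whether VK_KHR_cooperative_matrix should be used on this device. Some vendor/driver
// combinations misreport support or regress in performance, so per-vendor restrictions are
// applied below.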
15355static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
15356 switch (props.vendorID) {
15357 case VK_VENDOR_ID_INTEL:
15358 // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost,
15359 // while some older hardware (ex. Arc A770) has performance regressions
15360 return arch == vk_device_architecture::INTEL_XE2;
15361 case VK_VENDOR_ID_AMD:
15362 if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
15363 // Workaround for AMD proprietary driver reporting support on all GPUs
15364 return arch == vk_device_architecture::AMD_RDNA3;
15365 }
15366 return true;
15367 default:
15368 return true;
15369 }
15370}
15371
15372// checks
15373
15374#ifdef GGML_VULKAN_CHECK_RESULTS
15375static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) {
15376 if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) {
15377 return;
15378 }
15379 for (int j = 0; j < level; j++) {
15380 std::cerr << " ";
15381 }
15382 std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl;
15383
15384 done.push_back(tensor);
15385
15386 for (int i = 0; i < GGML_MAX_SRC; i++) {
15387 if (tensor->src[i] != nullptr) {
15388 ggml_vk_print_graph_origin(tensor->src[i], done, level + 1);
15389 }
15390 }
15391}
15392
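// Print a 10x10 window of values centered on (i0, i1) within slice (i2, i3). Only F32, F16 and
// I32 tensors are printed; other types are silently skipped.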
15393static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
15394 if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
15395 return;
15396 }
15397 i0 = std::max(i0, 5);
15398 i1 = std::max(i1, 5);
15399 i2 = std::max(i2, 0);
15400 i3 = std::max(i3, 0);
15401 fprintf(stderr, " ");
15402 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
15403 fprintf(stderr, "%7d ", idx1);
15404 }
15405 fprintf(stderr, "\n");
15406 for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) {
15407 fprintf(stderr, "%7d: ", idx0);
15408 for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) {
15409 if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) {
15410 float val;
15411 if (tensor->type == GGML_TYPE_F32) {
15412 val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
15413 } else if (tensor->type == GGML_TYPE_F16) {
15414 val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
15415 } else if (tensor->type == GGML_TYPE_I32) {
15416 val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
15417 } else {
15418 GGML_ABORT("fatal error");
15419 }
15420 fprintf(stderr, "% 7.2f ", val);
15421 } else {
15422 fprintf(stderr, " ");
15423 }
15424 }
15425 fprintf(stderr, "\n");
15426 }
15427}
15428
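// Dump a tensor's metadata, a small sample of its values and the op chain that produced it.
// GPU-resident tensors are read back into a temporary host buffer first.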
15429static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) {
15430 void * tensor_data = tensor->data;
15431
15432 const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer);
15433
15434 if (is_gpu) {
15435 const size_t tensor_size = ggml_nbytes(tensor);
15436 tensor_data = malloc(tensor_size);
15437
15438 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
15439
15440 vk_buffer buffer_gpu = buf_ctx->dev_buffer;
15441 ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size);
15442 }
15443
15444 std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl;
15445 std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl;
15446 if (tensor->src[0] != nullptr) {
15447 std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl;
15448 }
15449 if (tensor->src[1] != nullptr) {
15450 std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl;
15451 }
15452 std::cerr << std::endl << "Result:" << std::endl;
15453 ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
15454 std::cerr << std::endl;
15455 std::vector<const ggml_tensor *> done;
15456 ggml_vk_print_graph_origin(tensor, done);
15457
15458 if (is_gpu) {
15459 free(tensor_data);
15460 }
15461}
15462
15463void * comp_result;
15464size_t comp_size;
15465size_t comp_nb[GGML_MAX_DIMS];
15466size_t check_counter = 0;
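// Phase 1 of result checking: clone the node (plus any fused nodes) into a temporary CPU ggml
// context, read its GPU inputs back, compute the reference result on the CPU and keep it in
// comp_result/comp_nb for comparison in ggml_vk_check_results_1.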
15467static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
15468 ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
15469 if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
15470 return;
15471 }
15472
15473 check_counter++;
15474 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
15475 return;
15476 }
15477
15478 VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")");
15479
15480 struct ggml_init_params iparams = {
15481 /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
15482 /*.mem_buffer =*/ NULL,
15483 /*.no_alloc =*/ false,
15484 };
15485
15486 struct ggml_context * ggml_ctx = ggml_init(iparams);
15487
15488 std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
15489 const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
15490
15491 std::map<ggml_tensor *, ggml_tensor *> cloned_tensors;
15492 std::vector<void *> cloned_mallocs;
15493
15494 struct ggml_tensor * tensor_clone = nullptr;
15495
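// Clone each node of the fused group. Sources are deep-copied: host buffers via memcpy,
// Vulkan buffers by reading them back (slice by slice when only dims 0/1 are contiguous).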
15496 for (int f = 0; f < ctx->num_additional_fused_ops + 1; ++f) {
15497 tensor = cgraph->nodes[tensor_idx + f];
15498 for (int i = 0; i < GGML_MAX_SRC; i++) {
15499 ggml_tensor * srci = tensor->src[i];
15500 if (srci == nullptr) {
15501 continue;
15502 }
15503 // If a src tensor has been cloned, use that one
15504 auto it = cloned_tensors.find(srci);
15505 if (it != cloned_tensors.end()) {
15506 src_clone[i] = it->second;
15507 continue;
15508 }
15509 ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci);
15510 size_t srci_size = ggml_nbytes(srci);
15511
15512 src_clone[i] = srci_clone;
15513 void *src_buffer = malloc(srci_size);
15514 cloned_mallocs.push_back(src_buffer);
15515
15516 srci_clone->data = src_buffer;
15517 if (ggml_backend_buffer_is_host(srci->buffer)) {
15518 memcpy(srci_clone->data, srci->data, srci_size);
15519 memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
15520 } else if (ggml_backend_buffer_is_vk(srci->buffer)) {
15521 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context;
15522 vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
15523 uint64_t offset = vk_tensor_offset(srci) + srci->view_offs;
15524 if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) {
15525 for (int i3 = 0; i3 < srci->ne[3]; i3++) {
15526 for (int i2 = 0; i2 < srci->ne[2]; i2++) {
15527 const int idx = i3*srci->ne[2] + i2;
15528 ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]);
15529 }
15530 }
15531
15532 srci_clone->nb[0] = srci->nb[0];
15533 srci_clone->nb[1] = srci->nb[1];
15534 for (int i = 2; i < GGML_MAX_DIMS; i++) {
15535 srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1];
15536 }
15537 } else {
15538 if (offset + srci_size >= buffer_gpu->size) {
15539 srci_size = buffer_gpu->size - offset;
15540 }
15541 ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size);
15542 memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS);
15543 }
15544 } else {
15545 GGML_ABORT("fatal error");
15546 }
15547
15548 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
15549 ggml_vk_print_tensor(srci, srci_name[i]);
15550 }
15551 }
15552
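// Re-create the op on the CPU context with the same parameters as the Vulkan node;
// any op without a clone path below aborts.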
15553 if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
15554 const float * params = (const float *)tensor->op_params;
15555 tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
15556 if (src_clone[4]) {
15557 ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]);
15558 }
15559 } else if (tensor->op == GGML_OP_MUL_MAT) {
15560 tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
15561 } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
15562 tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
15563 } else if (tensor->op == GGML_OP_SUB) {
15564 tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
15565 } else if (tensor->op == GGML_OP_MUL) {
15566 tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
15567 } else if (tensor->op == GGML_OP_DIV) {
15568 tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
15569 } else if (tensor->op == GGML_OP_CONCAT) {
15570 tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
15571 } else if (tensor->op == GGML_OP_UPSCALE) {
15572 tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]);
15573 } else if (tensor->op == GGML_OP_SCALE) {
15574 const float * params = (const float *)tensor->op_params;
15575 tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]);
15576 } else if (tensor->op == GGML_OP_ADD1) {
15577 tensor_clone = ggml_add1(ggml_ctx, src_clone[0], src_clone[1]);
15578 } else if (tensor->op == GGML_OP_ARANGE) {
15579 const float start = ggml_get_op_params_f32(tensor, 0);
15580 const float stop = ggml_get_op_params_f32(tensor, 1);
15581 const float step = ggml_get_op_params_f32(tensor, 2);
15582 tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
15583 } else if (tensor->op == GGML_OP_FILL) {
15584 const float value = ggml_get_op_params_f32(tensor, 0);
15585 tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
15586 } else if (tensor->op == GGML_OP_SQR) {
15587 tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
15588 } else if (tensor->op == GGML_OP_SQRT) {
15589 tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]);
15590 } else if (tensor->op == GGML_OP_SIN) {
15591 tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
15592 } else if (tensor->op == GGML_OP_COS) {
15593 tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
15594 } else if (tensor->op == GGML_OP_LOG) {
15595 tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
15596 } else if (tensor->op == GGML_OP_TRI) {
15597 tensor_clone = ggml_tri(ggml_ctx, src_clone[0], (ggml_tri_type)ggml_get_op_params_i32(tensor, 0));
15598 } else if (tensor->op == GGML_OP_DIAG) {
15599 tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
15600 } else if (tensor->op == GGML_OP_CLAMP) {
15601 const float * params = (const float *)tensor->op_params;
15602 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
15603 } else if (tensor->op == GGML_OP_PAD) {
15604 tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
15605 tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
15606 } else if (tensor->op == GGML_OP_REPEAT) {
15607 tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
15608 } else if (tensor->op == GGML_OP_REPEAT_BACK) {
15609 tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor);
15610 } else if (tensor->op == GGML_OP_ADD) {
15611 tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
15612 } else if (tensor->op == GGML_OP_ACC) {
15613 tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
15614 } else if (tensor->op == GGML_OP_NORM) {
15615 tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
15616 } else if (tensor->op == GGML_OP_GROUP_NORM) {
15617 const float * float_params = (const float *)tensor->op_params;
15618 tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
15619 } else if (tensor->op == GGML_OP_RMS_NORM) {
15620 tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
15621 } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
15622 const float eps = ((float *) tensor->op_params)[0];
15623 tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
15624 } else if (tensor->op == GGML_OP_SILU_BACK) {
15625 tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
15626 } else if (tensor->op == GGML_OP_L2_NORM) {
15627 const float eps = ((float *) tensor->op_params)[0];
15628 tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
15629 } else if (tensor->op == GGML_OP_SOFT_MAX) {
15630 if (tensor->src[1] != nullptr) {
15631 const float * params = (const float *)tensor->op_params;
15632 tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
15633 } else {
15634 tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
15635 }
15636 } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
15637 tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
15638 } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
15639 tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
15640 } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
15641 const int n_dims = ((int32_t *) tensor->op_params)[1];
15642 const int mode = ((int32_t *) tensor->op_params)[2];
15643 //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
15644 const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4];
15645 const float freq_base = ((float *) tensor->op_params)[5];
15646 const float freq_scale = ((float *) tensor->op_params)[6];
15647 const float ext_factor = ((float *) tensor->op_params)[7];
15648 const float attn_factor = ((float *) tensor->op_params)[8];
15649 const float beta_fast = ((float *) tensor->op_params)[9];
15650 const float beta_slow = ((float *) tensor->op_params)[10];
15651 if (mode & GGML_ROPE_TYPE_MROPE) {
15652 int32_t *sections = ((int32_t *) tensor->op_params) + 11;
15653 if (tensor->op == GGML_OP_ROPE) {
15654 tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15655 } else {
15656 tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15657 }
15658 } else {
15659 if (tensor->op == GGML_OP_ROPE) {
15660 tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15661 } else {
15662 tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
15663 }
15664 }
15665 } else if (tensor->op == GGML_OP_UNARY) {
15666 switch (ggml_get_unary_op(tensor)) {
15667 case GGML_UNARY_OP_EXP:
15668 tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
15669 break;
15670 case GGML_UNARY_OP_SILU:
15671 tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
15672 break;
15673 case GGML_UNARY_OP_GELU:
15674 tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]);
15675 break;
15676 case GGML_UNARY_OP_GELU_ERF:
15677 tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]);
15678 break;
15679 case GGML_UNARY_OP_GELU_QUICK:
15680 tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]);
15681 break;
15682 case GGML_UNARY_OP_RELU:
15683 tensor_clone = ggml_relu(ggml_ctx, src_clone[0]);
15684 break;
15685 case GGML_UNARY_OP_XIELU:
15686 tensor_clone = ggml_xielu(ggml_ctx, src_clone[0], 0, 0, 0, 0);
15687 ggml_set_op_params_f32(tensor_clone, 1, ggml_get_op_params_f32(tensor, 1));
15688 ggml_set_op_params_f32(tensor_clone, 2, ggml_get_op_params_f32(tensor, 2));
15689 ggml_set_op_params_f32(tensor_clone, 3, ggml_get_op_params_f32(tensor, 3));
15690 ggml_set_op_params_f32(tensor_clone, 4, ggml_get_op_params_f32(tensor, 4));
15691 break;
15692 case GGML_UNARY_OP_NEG:
15693 tensor_clone = ggml_neg(ggml_ctx, src_clone[0]);
15694 break;
15695 case GGML_UNARY_OP_TANH:
15696 tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
15697 break;
15698 case GGML_UNARY_OP_SIGMOID:
15699 tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
15700 break;
15701 case GGML_UNARY_OP_HARDSIGMOID:
15702 tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]);
15703 break;
15704 case GGML_UNARY_OP_HARDSWISH:
15705 tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]);
15706 break;
15707 case GGML_UNARY_OP_ABS:
15708 tensor_clone = ggml_abs(ggml_ctx, src_clone[0]);
15709 break;
15710 case GGML_UNARY_OP_SOFTPLUS:
15711 tensor_clone = ggml_softplus(ggml_ctx, src_clone[0]);
15712 break;
15713 case GGML_UNARY_OP_STEP:
15714 tensor_clone = ggml_step(ggml_ctx, src_clone[0]);
15715 break;
15716 case GGML_UNARY_OP_ROUND:
15717 tensor_clone = ggml_round(ggml_ctx, src_clone[0]);
15718 break;
15719 case GGML_UNARY_OP_CEIL:
15720 tensor_clone = ggml_ceil(ggml_ctx, src_clone[0]);
15721 break;
15722 case GGML_UNARY_OP_FLOOR:
15723 tensor_clone = ggml_floor(ggml_ctx, src_clone[0]);
15724 break;
15725 case GGML_UNARY_OP_TRUNC:
15726 tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
15727 break;
15728 default:
15729 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
15730 GGML_ABORT("fatal error");
15731 }
15732 } else if (tensor->op == GGML_OP_GLU) {
15733 if (src_clone[1] == nullptr) {
15734 tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
15735 } else {
15736 tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
15737 }
15738 ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2));
15739 ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3));
15740 } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
15741 if (tensor->src[1] == nullptr) {
15742 tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
15743 tensor_clone->type = tensor->type;
15744 } else {
15745 tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]);
15746 }
15747 } else if (tensor->op == GGML_OP_CONT) {
15748 tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
15749 } else if (tensor->op == GGML_OP_RESHAPE) {
15750 tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
15751 } else if (tensor->op == GGML_OP_VIEW) {
15752 tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]);
15753 } else if (tensor->op == GGML_OP_PERMUTE) {
15754 int32_t * params = (int32_t *)tensor->op_params;
15755 tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]);
15756 } else if (tensor->op == GGML_OP_TRANSPOSE) {
15757 tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]);
15758 } else if (tensor->op == GGML_OP_GET_ROWS) {
15759 tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
15760 } else if (tensor->op == GGML_OP_ARGSORT) {
15761 tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
15762 } else if (tensor->op == GGML_OP_TOP_K) {
15763 tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
15764 } else if (tensor->op == GGML_OP_SUM) {
15765 tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
15766 } else if (tensor->op == GGML_OP_SUM_ROWS) {
15767 tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
15768 } else if (tensor->op == GGML_OP_CUMSUM) {
15769 tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
15770 } else if (tensor->op == GGML_OP_MEAN) {
15771 tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
15772 } else if (tensor->op == GGML_OP_ARGMAX) {
15773 tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
15774 } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
15775 tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
15776 } else if (tensor->op == GGML_OP_SOLVE_TRI) {
15777 tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
15778 } else if (tensor->op == GGML_OP_IM2COL) {
15779 const int32_t s0 = tensor->op_params[0];
15780 const int32_t s1 = tensor->op_params[1];
15781 const int32_t p0 = tensor->op_params[2];
15782 const int32_t p1 = tensor->op_params[3];
15783 const int32_t d0 = tensor->op_params[4];
15784 const int32_t d1 = tensor->op_params[5];
15785
15786 const bool is_2D = tensor->op_params[6] == 1;
15787 tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
15788 } else if (tensor->op == GGML_OP_IM2COL_3D) {
15789 const int32_t s0 = tensor->op_params[0];
15790 const int32_t s1 = tensor->op_params[1];
15791 const int32_t s2 = tensor->op_params[2];
15792 const int32_t p0 = tensor->op_params[3];
15793 const int32_t p1 = tensor->op_params[4];
15794 const int32_t p2 = tensor->op_params[5];
15795 const int32_t d0 = tensor->op_params[6];
15796 const int32_t d1 = tensor->op_params[7];
15797 const int32_t d2 = tensor->op_params[8];
15798 const int32_t IC = tensor->op_params[9];
15799
15800 tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
15801 } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
15802 const int32_t dim = tensor->op_params[0];
15803 const int32_t max_period = tensor->op_params[1];
15804 tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period);
15805 } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){
15806 const int32_t s0 = tensor->op_params[0];
15807 const int32_t p0 = tensor->op_params[1];
15808 const int32_t d0 = tensor->op_params[2];
15809 tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
15810 } else if (tensor->op == GGML_OP_POOL_2D) {
15811 enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
15812 const int32_t k0 = tensor->op_params[1];
15813 const int32_t k1 = tensor->op_params[2];
15814 const int32_t s0 = tensor->op_params[3];
15815 const int32_t s1 = tensor->op_params[4];
15816 const int32_t p0 = tensor->op_params[5];
15817 const int32_t p1 = tensor->op_params[6];
15818
15819 tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1);
15820 } else if (tensor->op == GGML_OP_CONV_2D) {
15821 const int32_t s0 = tensor->op_params[0];
15822 const int32_t s1 = tensor->op_params[1];
15823 const int32_t p0 = tensor->op_params[2];
15824 const int32_t p1 = tensor->op_params[3];
15825 const int32_t d0 = tensor->op_params[4];
15826 const int32_t d1 = tensor->op_params[5];
15827 tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
15828 } else if (tensor->op == GGML_OP_CONV_2D_DW) {
15829 const int32_t s0 = tensor->op_params[0];
15830 const int32_t s1 = tensor->op_params[1];
15831 const int32_t p0 = tensor->op_params[2];
15832 const int32_t p1 = tensor->op_params[3];
15833 const int32_t d0 = tensor->op_params[4];
15834 const int32_t d1 = tensor->op_params[5];
15835 tensor_clone = ggml_conv_2d_dw_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1);
15836 } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) {
15837 const int32_t s = tensor->op_params[0];
15838 tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s);
15839 } else if (tensor->op == GGML_OP_LEAKY_RELU) {
15840 const float * op_params = (const float *)tensor->op_params;
15841 tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false);
15842 } else if (tensor->op == GGML_OP_RWKV_WKV6) {
15843 tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
15844 src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
15845 } else if (tensor->op == GGML_OP_RWKV_WKV7) {
15846 tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
15847 src_clone[4], src_clone[5], src_clone[6]);
15848 } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
15849 src_clone[0]->flags = tensor->src[0]->flags;
15850 tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
15851 src_clone[2], src_clone[3], src_clone[4]);
15852 } else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
15853 src_clone[0]->flags = tensor->src[0]->flags;
15854 tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
15855 src_clone[2]);
15856 } else if (tensor->op == GGML_OP_ADD_ID) {
15857 tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
15858 } else if (tensor->op == GGML_OP_SSM_SCAN) {
15859 tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
15860 src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
15861 } else if (tensor->op == GGML_OP_SSM_CONV) {
15862 tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
15863 } else if (tensor->op == GGML_OP_ROLL) {
15864 const int32_t s0 = tensor->op_params[0];
15865 const int32_t s1 = tensor->op_params[1];
15866 const int32_t s2 = tensor->op_params[2];
15867 const int32_t s3 = tensor->op_params[3];
15868 tensor_clone = ggml_roll(ggml_ctx, src_clone[0], s0, s1, s2, s3);
15869 }
15870 else {
15871 std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
15872 GGML_ABORT("fatal error");
15873 }
15874 cloned_tensors[tensor] = tensor_clone;
15875 }
15876
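// Build a CPU graph from the cloned ops and compute the reference result (8 threads);
// the output bytes and strides are stashed in comp_result/comp_nb for phase 2.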
15877 ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
15878 ggml_build_forward_expand(cgraph_cpu, tensor_clone);
15879
15880 ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
15881
15882 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
15883 ggml_vk_print_tensor(tensor_clone, "tensor_clone");
15884 }
15885
15886 comp_size = ggml_nbytes(tensor_clone);
15887
15888 comp_result = malloc(comp_size);
15889 memcpy(comp_result, tensor_clone->data, comp_size);
15890 memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
15891
15892 for (auto m : cloned_mallocs) {
15893 free(m);
15894 }
15895
15896 ggml_free(ggml_ctx);
15897
15898 VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
15899}
15900
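// Phase 2 of result checking: read the Vulkan result back and compare it element-wise against
// the CPU reference from ggml_vk_check_results_0. NaN/Inf mismatches abort immediately; otherwise
// the run aborts when the average relative error exceeds 0.5.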
15901static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
15902 ggml_tensor * tensor = cgraph->nodes[tensor_idx + ctx->num_additional_fused_ops];
15903 if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) {
15904 return;
15905 }
15906
15907 if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
15908 return;
15909 }
15910
15911 VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")");
15912
15913 ggml_tensor * src0 = tensor->src[0];
15914 ggml_tensor * src1 = tensor->src[1];
15915 ggml_tensor * src2 = tensor->src[2];
15916 ggml_tensor * src3 = tensor->src[3];
15917
15918 void * tensor_data = tensor->data;
15919
15920 if (ggml_backend_buffer_is_vk(tensor->buffer)) {
15921 size_t tensor_size = ggml_nbytes(tensor);
15922 tensor_data = malloc(tensor_size);
15923
15924 ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
15925
15926 vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
15927 uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs;
15928 if (offset + tensor_size >= buffer_gpu->size) {
15929 tensor_size = buffer_gpu->size - offset;
15930 }
15931
15932 ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size);
15933 }
15934
15935 float first_error_result = -1.0f;
15936 float first_error_correct = -1.0f;
15937 std::array<int, 4> first_error = { -1, -1, -1, -1 };
15938 double avg_err = 0.0;
15939 size_t counter = 0;
15940
15941 for (int i3 = 0; i3 < tensor->ne[3]; i3++) {
15942 for (int i2 = 0; i2 < tensor->ne[2]; i2++) {
15943 for (int i1 = 0; i1 < tensor->ne[1]; i1++) {
15944 for (int i0 = 0; i0 < tensor->ne[0]; i0++) {
15945 const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size;
15946 float correct = 0.0f;
15947 float result = 0.0f;
15948
15949 if (buffer_size_fit) {
15950 if (tensor->type == GGML_TYPE_F32) {
15951 correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15952 result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15953 } else if (tensor->type == GGML_TYPE_F16) {
15954 correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
15955 result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
15956 } else if (tensor->type == GGML_TYPE_BF16) {
15957 correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
15958 result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
15959 } else if (tensor->type == GGML_TYPE_I32) {
15960 correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15961 result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15962 } else if (tensor->type == GGML_TYPE_I64) {
15963 correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
15964 result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
15965 } else {
15966 std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
15967 }
15968 } else {
15969 std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
15970 GGML_ABORT("fatal error");
15971 }
15972
15973 if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) {
15974 std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl;
15975 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
15976 if (src0 != nullptr) {
15977 std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
15978 }
15979 if (src1 != nullptr) {
15980 std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
15981 }
15982 if (src2 != nullptr) {
15983 std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
15984 }
15985 if (src3 != nullptr) {
15986 std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
15987 }
15988 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
15989 std::cerr << std::endl << "Result:" << std::endl;
15990 ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
15991 std::cerr << std::endl << "Correct:" << std::endl;
15992 ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3);
15993 std::cerr << std::endl;
15994 std::vector<const ggml_tensor *> done;
15995 ggml_vk_print_graph_origin(tensor, done);
15996 GGML_ABORT("fatal error");
15997 }
15998 const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? std::fabs(correct) : 1e-8) : 1.0f;
15999 if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) {
16000 first_error[0] = i0;
16001 first_error[1] = i1;
16002 first_error[2] = i2;
16003 first_error[3] = i3;
16004 first_error_result = result;
16005 first_error_correct = correct;
16006 }
16007
16008 // Skip Inf/NaN values so they do not poison avg_err; if both result and reference are
16009 // NaN or Inf the error is treated as zero (mismatched pairs were already caught above).
16010 if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) {
16011 avg_err += std::fabs(correct - result) / denom;
16012 }
16013 counter++;
16014 }
16015 }
16016 }
16017 }
16018
16019 avg_err /= counter;
16020
16021 if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
16022 std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
16023 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
16024 if (src0 != nullptr) {
16025 std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
16026 }
16027 if (src1 != nullptr) {
16028 std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
16029 }
16030 if (src2 != nullptr) {
16031 std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
16032 }
16033 if (src3 != nullptr) {
16034 std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
16035 }
16036 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
16037 std::cerr << std::endl << "Result:" << std::endl;
16038 ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
16039 std::cerr << std::endl << "Correct:" << std::endl;
16040 ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0);
16041 std::cerr << std::endl;
16042 std::vector<const ggml_tensor *> done;
16043 ggml_vk_print_graph_origin(tensor, done);
16044 }
16045
16046 if (avg_err > 0.5 || std::isnan(avg_err)) {
16047 std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
16048 std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
16049 if (src0 != nullptr) {
16050 std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl;
16051 }
16052 if (src1 != nullptr) {
16053 std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl;
16054 }
16055 if (src2 != nullptr) {
16056 std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
16057 }
16058 if (src3 != nullptr) {
16059 std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
16060 }
16061 std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
16062 std::cerr << std::endl << "Result:" << std::endl;
16063 ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);
16064 std::cerr << std::endl << "Correct:" << std::endl;
16065 ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]);
16066 std::cerr << std::endl;
16067 std::vector<const ggml_tensor *> done;
16068 ggml_vk_print_graph_origin(tensor, done);
16069 GGML_ABORT("fatal error");
16070 } else {
16071 std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl;
16072 }
16073
16074 free(comp_result);
16075 comp_result = nullptr;
16076 comp_size = 0;
16077
16078 if (ggml_backend_buffer_is_vk(tensor->buffer)) {
16079 free(tensor_data);
16080 }
16081
16082 VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
16083}
16084#endif
16085
16086GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)